rewicks commited on
Commit
14ce8c9
·
verified ·
1 Parent(s): 761f422

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +29 -5
model.py CHANGED
@@ -110,6 +110,10 @@ class LidirlCNN(PreTrainedModel):
110
  self.multilabel = config.multilabel
111
  self.monte_carlo = config.montecarlo_layer
112
 
 
 
 
 
113
 
114
  def forward(self, inputs, lengths):
115
  inputs = inputs[:, :self.max_length]
@@ -123,9 +127,29 @@ class LidirlCNN(PreTrainedModel):
123
  return projection
124
 
125
  def __call__(self, inputs, lengths):
126
- logits = self.forward(inputs, lengths)
 
 
 
 
 
 
 
 
 
 
127
  if self.multilabel:
128
- probs = torch.sigmoid(logits)
129
- else:
130
- probs = torch.softmax(logits, dim=-1)
131
- return probs
 
 
 
 
 
 
 
 
 
 
 
110
  self.multilabel = config.multilabel
111
  self.monte_carlo = config.montecarlo_layer
112
 
113
+ self.labels = ["" for _ in config.labels]
114
+ for key, value in config.labels.items():
115
+ self.labels[value] = key
116
+
117
 
118
  def forward(self, inputs, lengths):
119
  inputs = inputs[:, :self.max_length]
 
127
  return projection
128
 
129
  def __call__(self, inputs, lengths):
130
+ # this is inference only model
131
+ with torch.no_grad():
132
+ logits = self.forward(inputs, lengths)
133
+ if self.multilabel:
134
+ probs = torch.sigmoid(logits)
135
+ else:
136
+ probs = torch.softmax(logits, dim=-1)
137
+ return probs
138
+
139
+ def predict(self, inputs, lengths, threshold=0.5):
140
+ probs = self.__call__(inputs, lengths)
141
  if self.multilabel:
142
+ batch_idx, label_idx = torch.where(probs > threshold)
143
+ output = []
144
+ for batch, label in zip(batch_idx, label_idx):
145
+ if len(output) < batch.item():
146
+ output.append([])
147
+ label_string = self.labels
148
+ output[-1].append(
149
+ (self.labels[label.item()], probs[batch, label])
150
+ )
151
+ return output
152
+
153
+
154
+
155
+