Update model.py
Browse files
model.py
CHANGED
|
@@ -110,6 +110,10 @@ class LidirlCNN(PreTrainedModel):
|
|
| 110 |
self.multilabel = config.multilabel
|
| 111 |
self.monte_carlo = config.montecarlo_layer
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
def forward(self, inputs, lengths):
|
| 115 |
inputs = inputs[:, :self.max_length]
|
|
@@ -123,9 +127,29 @@ class LidirlCNN(PreTrainedModel):
|
|
| 123 |
return projection
|
| 124 |
|
| 125 |
def __call__(self, inputs, lengths):
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
if self.multilabel:
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
self.multilabel = config.multilabel
|
| 111 |
self.monte_carlo = config.montecarlo_layer
|
| 112 |
|
| 113 |
+
self.labels = ["" for _ in config.labels]
|
| 114 |
+
for key, value in config.labels.items():
|
| 115 |
+
self.labels[value] = key
|
| 116 |
+
|
| 117 |
|
| 118 |
def forward(self, inputs, lengths):
|
| 119 |
inputs = inputs[:, :self.max_length]
|
|
|
|
| 127 |
return projection
|
| 128 |
|
| 129 |
def __call__(self, inputs, lengths):
|
| 130 |
+
# this is inference only model
|
| 131 |
+
with torch.no_grad():
|
| 132 |
+
logits = self.forward(inputs, lengths)
|
| 133 |
+
if self.multilabel:
|
| 134 |
+
probs = torch.sigmoid(logits)
|
| 135 |
+
else:
|
| 136 |
+
probs = torch.softmax(logits, dim=-1)
|
| 137 |
+
return probs
|
| 138 |
+
|
| 139 |
+
def predict(self, inputs, lengths, threshold=0.5):
|
| 140 |
+
probs = self.__call__(inputs, lengths)
|
| 141 |
if self.multilabel:
|
| 142 |
+
batch_idx, label_idx = torch.where(probs > threshold)
|
| 143 |
+
output = []
|
| 144 |
+
for batch, label in zip(batch_idx, label_idx):
|
| 145 |
+
if len(output) < batch.item():
|
| 146 |
+
output.append([])
|
| 147 |
+
label_string = self.labels
|
| 148 |
+
output[-1].append(
|
| 149 |
+
(self.labels[label.item()], probs[batch, label])
|
| 150 |
+
)
|
| 151 |
+
return output
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
|