from transformers import RobertaConfig
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaModel
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
import torch.nn as nn
import torch
class LabelRobertaForTokenClassification(RobertaPreTrainedModel):
    """RoBERTa token classifier that labels each token jointly with the token
    it points to.

    For every position ``i`` the classifier concatenates the RoBERTa embedding
    of token ``i`` with the embedding of the token indexed by
    ``head_labels[:, i]`` (passed through ``**kwargs``), then scores the pair
    with a small 2-layer MLP. Positions whose head index is -100 (padding) are
    redirected to index 0 and, when labels are given, excluded from the loss.
    """

    config_class = RobertaConfig

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # NOTE(review): label count is fixed at 33 and ignores
        # config.num_labels — confirm this matches the upstream label set.
        self.num_labels = 33
        # Generalized from the hard-coded 768 so checkpoints with a different
        # hidden size also work (768 == roberta-base, behavior unchanged).
        hidden_size = config.hidden_size
        # Pairwise head: (token ⊕ head-token) -> hidden -> label logits.
        self.hidden = nn.Linear(hidden_size * 2, hidden_size)
        self.relu = nn.ReLU()
        self.out = nn.Linear(hidden_size, self.num_labels)
        # CrossEntropyLoss ignores target index -100 by default, matching the
        # padding convention used for `labels` below.
        self.loss_fct = nn.CrossEntropyLoss()
        # NOTE(review): no self.post_init()/init_weights() call here — the new
        # heads rely on nn.Linear's default init; confirm this is intended.

    def batched_index_select(self, input, dim, index):
        """Gather slices of ``input`` along ``dim`` with a per-batch ``index``.

        ``index`` has shape (batch, k); the result keeps ``input``'s shape
        except that ``dim`` has size k. Equivalent to calling
        ``input[b].index_select(dim - 1, index[b])`` for every batch row b.
        """
        views = [input.shape[0]] + \
            [1 if i != dim else -1 for i in range(1, len(input.shape))]
        expanse = list(input.shape)
        expanse[0] = -1
        expanse[dim] = -1
        index = index.view(views).expand(expanse)
        return torch.gather(input, dim, index)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                labels=None, **kwargs):
        """Run RoBERTa and score every (token, head-token) pair.

        Args:
            input_ids / attention_mask / token_type_ids: standard RoBERTa
                inputs, forwarded unchanged.
            labels: optional (batch, seq_len) gold labels; -100 marks ignored
                positions (CrossEntropyLoss convention).
            kwargs: must contain ``head_labels`` — a (batch, seq_len) tensor
                of head-token indices, -100 for padding.

        Returns:
            TokenClassifierOutput with ``logits`` of shape
            (batch, seq_len, num_labels) and ``loss`` (None when ``labels``
            is not supplied — previously 0.0 was returned, which broke the
            usual ``output.loss is not None`` check).
        """
        sequence_output = self.roberta(input_ids, attention_mask=attention_mask,
                                       token_type_ids=token_type_ids)[0]
        batch_size, seq_len, _ = sequence_output.size()

        loss = 0.0 if labels is not None else None
        logits = []
        for i in range(seq_len):
            current_token = sequence_output[:, i, :]
            # BUGFIX: clone BEFORE substituting padding. The original code
            # assigned into a view of kwargs["head_labels"], mutating the
            # caller's tensor in place (its .clone() happened after the write).
            connected_with_index = kwargs["head_labels"][:, i].clone()
            # Redirect -100 padding to a safe dummy index (position 0); these
            # positions are masked out of the loss via the -100 labels anyway.
            connected_with_index[connected_with_index == -100] = 0
            # No .clone() on sequence_output here: torch.gather does not
            # mutate its input, so the original per-step full copy was wasted.
            connected_with_embedding = self.batched_index_select(
                sequence_output, 1, connected_with_index)
            combined_embeddings = torch.cat(
                (current_token, connected_with_embedding.squeeze(1)), -1)
            pred = self.out(self.relu(self.hidden(combined_embeddings)))
            pred = pred.view(-1, self.num_labels)
            logits.append(pred)
            if labels is not None:
                current_loss = self.loss_fct(pred, labels[:, i].view(-1))
                # A fully-padded position yields a NaN loss (no valid
                # targets); skip it instead of poisoning the sum.
                if not torch.all(labels[:, i] == -100):
                    loss += current_loss
        if labels is not None:
            # Average over all positions, matching the original scaling.
            loss = loss / seq_len

        logits = torch.stack(logits, dim=1)
        return TokenClassifierOutput(loss=loss, logits=logits)