from transformers import RobertaConfig
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel
import torch
import torch.nn as nn


class LabelRobertaForTokenClassification(RobertaPreTrainedModel):
    config_class = RobertaConfig

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.num_labels = 33  # hardcoded size of the label set
        # Classification head: the concatenation of the current token embedding and
        # its head-token embedding (2 * 768) is projected back to 768, then to the labels.
        self.hidden = nn.Linear(768 * 2, 768)
        self.relu = nn.ReLU()
        self.out = nn.Linear(768, self.num_labels)
        # CrossEntropyLoss ignores -100 targets by default (ignore_index=-100).
        self.loss_fct = nn.CrossEntropyLoss()
        # Initialize the weights of the newly added layers.
        self.post_init()

    def batched_index_select(self, input, dim, index):
        # For every example in the batch, gather the slice of `input` along `dim`
        # at the position given by `index`.
        views = [input.shape[0]] + [1 if i != dim else -1
                                    for i in range(1, len(input.shape))]
        expanse = list(input.shape)
        expanse[0] = -1
        expanse[dim] = -1
        index = index.view(views).expand(expanse)
        return torch.gather(input, dim, index)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                labels=None, **kwargs):
        loss = 0.0
        # Last hidden states of the encoder: (batch_size, seq_len, hidden_size).
        output = self.roberta(input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids)[0]
        batch_size, seq_len, _ = output.size()

        logits = []
        for i in range(seq_len):
            # Embedding of the token at position i: (batch_size, hidden_size).
            current_token = output[:, i, :]
            # Index of the token it is connected with (its "head"). Clone before
            # overwriting the -100 padding so the caller's head_labels tensor is
            # not modified in place.
            connected_with_index = kwargs["head_labels"][:, i].clone()
            connected_with_index[connected_with_index == -100] = 0
            connected_with_embedding = self.batched_index_select(
                output, 1, connected_with_index)
            # Concatenate the token embedding with its head-token embedding.
            combined_embeddings = torch.cat(
                (current_token, connected_with_embedding.squeeze(1)), -1)
            pred = self.out(self.relu(self.hidden(combined_embeddings)))
            pred = pred.view(-1, self.num_labels)
            logits.append(pred)

            if labels is not None:
                current_loss = self.loss_fct(pred, labels[:, i].view(-1))
                # Skip positions where every label is -100; otherwise the loss is NaN.
                if not torch.all(labels[:, i] == -100):
                    loss += current_loss

        logits = torch.stack(logits, dim=1)
        loss = loss / seq_len if labels is not None else None
        return TokenClassifierOutput(loss=loss, logits=logits)
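
# Minimal usage sketch (an assumption, not part of the original code): it loads the
# "roberta-base" checkpoint, builds dummy `labels` and `head_labels`, and runs one
# forward pass. head_labels holds, for each token, the index of the token it is
# connected with (-100 for ignored positions); the all-zeros values below are only
# illustrative placeholders, not a real annotation scheme.
from transformers import RobertaConfig, RobertaTokenizerFast
import torch

config = RobertaConfig.from_pretrained("roberta-base")
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
model = LabelRobertaForTokenClassification.from_pretrained("roberta-base", config=config)

enc = tokenizer("Hello world", return_tensors="pt")
seq_len = enc["input_ids"].shape[1]

# Every token "points at" position 0 here, and every token gets class 0.
head_labels = torch.zeros((1, seq_len), dtype=torch.long)
labels = torch.zeros((1, seq_len), dtype=torch.long)

out = model(input_ids=enc["input_ids"],
            attention_mask=enc["attention_mask"],
            labels=labels,
            head_labels=head_labels)
print(out.loss, out.logits.shape)  # logits: (1, seq_len, 33)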