depenBERTa_labler_perseus / modeling_depenberta_labler.py
bowphs's picture
Update modeling_depenberta_labler.py
f77af38
from transformers import RobertaConfig
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaModel
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
import torch.nn as nn
import torch
class LabelRobertaForTokenClassification(RobertaPreTrainedModel):
    """RoBERTa encoder with a token classifier over (token, linked-token) pairs.

    For every position the encoder state of that token is concatenated with
    the encoder state of the position named by ``head_labels`` for that token,
    and the pair is classified into one of ``num_labels`` classes by a
    two-layer MLP (``hidden`` -> ReLU -> ``out``).

    NOTE(review): ``head_labels`` presumably holds dependency-head positions
    and the 33 classes are dependency relation labels -- confirm against the
    training data.
    """

    config_class = RobertaConfig

    def __init__(self, config):
        """Build the encoder and the pair-classification head from ``config``."""
        super().__init__(config)
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # The number of output classes is hard-coded rather than read from
        # the config.
        self.num_labels = 33
        # Size the head from the encoder width so configs other than the
        # 768-wide one work; for hidden_size == 768 the shapes are unchanged.
        hidden_size = config.hidden_size
        self.hidden = nn.Linear(hidden_size * 2, hidden_size)
        self.relu = nn.ReLU()
        self.out = nn.Linear(hidden_size, self.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()

    def batched_index_select(self, input, dim, index):
        """Gather entries of ``input`` along ``dim`` with a per-batch index.

        ``index`` is reshaped so that its first axis lines up with the batch
        axis of ``input`` and its remaining elements lie along ``dim``; it is
        then broadcast over every other axis of ``input`` before
        ``torch.gather`` is applied.

        Args:
            input: Tensor whose first axis is the batch axis.
            dim: Axis (>= 1) of ``input`` to select along.
            index: Integer tensor with ``input.shape[0]`` as its first axis.

        Returns:
            Tensor shaped like ``input`` except that axis ``dim`` has one
            entry per index element of each batch item.
        """
        # Batch axis first, all index elements along ``dim``, 1 elsewhere.
        view_shape = [input.shape[0]] + [
            -1 if axis == dim else 1 for axis in range(1, len(input.shape))
        ]
        # -1 keeps the (reshaped) index size on the batch and ``dim`` axes;
        # every other axis is broadcast to the size it has in ``input``.
        expand_shape = list(input.shape)
        expand_shape[0] = -1
        expand_shape[dim] = -1
        # reshape (not view) so non-contiguous indices are accepted as well.
        index = index.reshape(view_shape).expand(expand_shape)
        return torch.gather(input, dim, index)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                labels=None, **kwargs):
        """Classify each token together with the token it is linked to.

        Args:
            input_ids: Passed to the RoBERTa encoder.
            attention_mask: Passed to the RoBERTa encoder.
            token_type_ids: Passed to the RoBERTa encoder.
            labels: Optional ``(batch, seq_len)`` class targets; ``-100``
                marks positions that are not scored.
            **kwargs: Must contain ``head_labels``, a ``(batch, seq_len)``
                integer tensor giving, for each token, the position whose
                encoder state is paired with it. ``-100`` entries are read
                as position 0.

        Returns:
            TokenClassifierOutput with ``logits`` of shape
            ``(batch, seq_len, num_labels)`` and ``loss`` set only when
            ``labels`` is given (``None`` otherwise).

        Raises:
            KeyError: If ``head_labels`` is not supplied.
        """
        head_labels = kwargs["head_labels"]
        output = self.roberta(input_ids, attention_mask=attention_mask,
                              token_type_ids=token_type_ids)[0]
        seq_len = output.size(1)

        # Replace the ignore marker out of place: the caller's ``head_labels``
        # tensor must not be modified by a forward pass.
        head_index = head_labels[:, :seq_len]
        head_index = head_index.masked_fill(head_index == -100, 0)

        # One gather for all positions: (batch, seq_len, hidden).
        head_embeddings = self.batched_index_select(output, 1, head_index)
        combined_embeddings = torch.cat((output, head_embeddings), dim=-1)
        logits = self.out(self.relu(self.hidden(combined_embeddings)))

        loss = None
        if labels is not None:
            # Positions where every batch item is ignored would give a NaN
            # mean, so they are skipped; determine them once up front.
            scored = (labels[:, :seq_len] != -100).any(dim=0).tolist()
            loss = logits.new_zeros(())
            for position in range(seq_len):
                if scored[position]:
                    loss = loss + self.loss_fct(logits[:, position, :],
                                                labels[:, position])
            # Averaged over all positions, including the skipped ones.
            loss = loss / seq_len

        return TokenClassifierOutput(loss=loss, logits=logits)