import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional

from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel


class MMNLIConfig(PretrainedConfig):
    """Configuration for the MMNLI sentence-pair classifier."""

    model_type = "mmnli"

    def __init__(
        self,
        embedding_dim: int = 1024,
        hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.1,
        activation: str = "TANH",
        norm_emb: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.hidden_dims = hidden_dims if hidden_dims is not None else [3072, 1536]
        self.dropout = dropout
        self.activation = activation
        self.norm_emb = norm_emb
        # NLI is a three-way task (entailment / neutral / contradiction),
        # so the output dimension is fixed.
        self.output_dim = 3


ACTIVATIONS = {"TANH": nn.Tanh, "RELU": nn.ReLU}


class MMNLICore(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        hidden_dims: List[int],
        dropout: float,
        activation: str,
        norm_emb: bool,
    ):
        super().__init__()
        self.norm_emb = norm_emb

        if activation not in ACTIVATIONS:
            raise ValueError(f"Unrecognized activation: {activation}")

        # featurize() concatenates [p, h, p * h, |p - h|], so the MLP input
        # is four times the embedding size.
        input_dim = embedding_dim * 4

        modules: List[nn.Module] = []
        if dropout > 0:
            modules.append(nn.Dropout(p=dropout))

        nprev = input_dim
        for h in hidden_dims:
            modules.append(nn.Linear(nprev, h))
            modules.append(ACTIVATIONS[activation]())
            if dropout > 0:
                modules.append(nn.Dropout(p=dropout))
            nprev = h

        # Three-way classification head; the trailing softmax means the
        # network outputs probabilities rather than logits.
        modules.append(nn.Linear(nprev, 3))
        modules.append(nn.Softmax(dim=-1))

        self.mlp = nn.Sequential(*modules)

    def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
        # L2-normalize along the embedding dimension. dim=-1 is explicit
        # because F.normalize defaults to dim=1, which only coincides with
        # the last dimension for 2-D inputs.
        return F.normalize(emb, dim=-1) if (emb is not None and self.norm_emb) else emb

    def featurize(self, premise: Tensor, hypothesis: Tensor) -> Tensor:
        # Standard sentence-pair features: both embeddings plus their
        # element-wise product and absolute difference.
        return torch.cat(
            [premise, hypothesis, premise * hypothesis, torch.abs(premise - hypothesis)],
            dim=-1,
        )


class MMNLIModel(PreTrainedModel):
    config_class = MMNLIConfig

    def __init__(self, config: MMNLIConfig):
        super().__init__(config)
        self.core = MMNLICore(
            embedding_dim=config.embedding_dim,
            hidden_dims=config.hidden_dims,
            dropout=config.dropout,
            activation=config.activation,
            norm_emb=config.norm_emb,
        )

    def forward(self, premise: Tensor, hypothesis: Tensor):
        # Inputs are precomputed sentence embeddings of size embedding_dim,
        # not token ids; an upstream encoder produces them.
        premise = self.core._norm(premise)
        hypothesis = self.core._norm(hypothesis)
        proc = self.core.featurize(premise, hypothesis)
        return self.core.mlp(proc)
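

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the
    # original module): feed random stand-ins for sentence embeddings
    # through the classifier and check the output shape. Real inputs would
    # come from an upstream sentence encoder, which this file does not define.
    config = MMNLIConfig()
    model = MMNLIModel(config).eval()  # eval() disables dropout

    premise = torch.randn(2, config.embedding_dim)
    hypothesis = torch.randn(2, config.embedding_dim)

    with torch.no_grad():
        probs = model(premise, hypothesis)

    assert probs.shape == (2, config.output_dim)
    # The softmax head returns probabilities, so each row sums to 1.
    print(probs.sum(dim=-1))  # ~tensor([1., 1.])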