import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel
# ---------------- CONFIG ---------------- #
class MMNLIConfig(PretrainedConfig):
    model_type = "mmnli"

    def __init__(
        self,
        embedding_dim: int = 1024,
        hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.1,
        activation: str = "TANH",
        norm_emb: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.hidden_dims = hidden_dims if hidden_dims is not None else [3072, 1536]
        self.dropout = dropout
        self.activation = activation
        self.norm_emb = norm_emb
        self.output_dim = 3  # entailment, contradiction, neutral


# ---------------- CORE MODEL ---------------- #
ACTIVATIONS = {"TANH": nn.Tanh, "RELU": nn.ReLU}


class MMNLICore(nn.Module):
    """MLP classifier over concatenated premise/hypothesis embedding features."""

    def __init__(
        self,
        embedding_dim: int,
        hidden_dims: List[int],
        dropout: float,
        activation: str,
        norm_emb: bool,
    ):
        super().__init__()
        self.norm_emb = norm_emb
        if activation not in ACTIVATIONS:
            raise ValueError(
                f"Unrecognized activation: {activation!r}; expected one of {sorted(ACTIVATIONS)}"
            )
        # Input: concatenation of [p, h, p*h, |p-h|] => 4 * embedding_dim
        input_dim = embedding_dim * 4
        modules: List[nn.Module] = []
        if dropout > 0:
            modules.append(nn.Dropout(p=dropout))
        nprev = input_dim
        for h in hidden_dims:
            modules.append(nn.Linear(nprev, h))
            modules.append(ACTIVATIONS[activation]())
            if dropout > 0:
                modules.append(nn.Dropout(p=dropout))
            nprev = h
        # Final classifier layer: 3-way softmax, so the MLP outputs class
        # probabilities rather than logits.
        modules.append(nn.Linear(nprev, 3))
        modules.append(nn.Softmax(dim=-1))
        self.mlp = nn.Sequential(*modules)

    def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
        # Optionally L2-normalize embeddings along the feature dimension.
        return F.normalize(emb, dim=-1) if (emb is not None and self.norm_emb) else emb

    def featurize(self, premise: Tensor, hypothesis: Tensor) -> Tensor:
        # Build the [p, h, p*h, |p-h|] feature vector.
        return torch.cat(
            [premise, hypothesis, premise * hypothesis, torch.abs(premise - hypothesis)],
            dim=-1,
        )


# ---------------- HF MODEL WRAPPER ---------------- #
class MMNLIModel(PreTrainedModel):
    config_class = MMNLIConfig

    def __init__(self, config: MMNLIConfig):
        super().__init__(config)
        self.core = MMNLICore(
            embedding_dim=config.embedding_dim,
            hidden_dims=config.hidden_dims,
            dropout=config.dropout,
            activation=config.activation,
            norm_emb=config.norm_emb,
        )

    def forward(self, premise: Tensor, hypothesis: Tensor) -> Tensor:
        # premise / hypothesis: sentence embeddings of shape (batch, embedding_dim).
        # Returns probabilities of shape (batch, 3) over
        # (entailment, contradiction, neutral).
        premise = self.core._norm(premise)
        hypothesis = self.core._norm(hypothesis)
        proc = self.core.featurize(premise, hypothesis)
        return self.core.mlp(proc)
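

# A minimal smoke-test sketch, assuming the inputs are precomputed sentence
# embeddings of size embedding_dim (how those embeddings are produced is
# outside this file). Random tensors only check shapes; a trained checkpoint
# would instead be loaded with MMNLIModel.from_pretrained(...).
if __name__ == "__main__":
    config = MMNLIConfig(embedding_dim=1024)
    model = MMNLIModel(config).eval()

    batch = 2
    premise = torch.randn(batch, config.embedding_dim)
    hypothesis = torch.randn(batch, config.embedding_dim)

    with torch.no_grad():
        probs = model(premise, hypothesis)

    # probs has shape (batch, 3); columns correspond to
    # (entailment, contradiction, neutral) and sum to 1 along dim=-1.
    print(probs.shape, probs.sum(dim=-1))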