# multimodal_nli_model/modeling_mmnli.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel
# ---------------- CONFIG ---------------- #
class MMNLIConfig(PretrainedConfig):
model_type = "mmnli"
def __init__(
self,
embedding_dim: int = 1024,
hidden_dims: Optional[List[int]] = None,
dropout: float = 0.1,
activation: str = "TANH",
norm_emb: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.embedding_dim = embedding_dim
self.hidden_dims = hidden_dims if hidden_dims is not None else [3072, 1536]
self.dropout = dropout
self.activation = activation
self.norm_emb = norm_emb
self.output_dim = 3 # entailment, contradiction, neutral
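
# Sketch: with the defaults above (embedding_dim=1024, hidden_dims=[3072, 1536]),
# the core MLP below comes out shaped 4096 -> 3072 -> 1536 -> 3, because its
# input is the 4 * embedding_dim feature vector built in featurize().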
# ---------------- CORE MODEL ---------------- #
ACTIVATIONS = {"TANH": nn.Tanh, "RELU": nn.ReLU}
class MMNLICore(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        hidden_dims: List[int],
        dropout: float,
        activation: str,
        norm_emb: bool,
        output_dim: int = 3,
    ):
super().__init__()
self.norm_emb = norm_emb
        if activation not in ACTIVATIONS:
            raise ValueError(
                f"Unrecognized activation: {activation!r}; expected one of {sorted(ACTIVATIONS)}"
            )
# Input: concatenation of [p, h, p*h, |p-h|] => 4 * embedding_dim
input_dim = embedding_dim * 4
modules: List[nn.Module] = []
if dropout > 0:
modules.append(nn.Dropout(p=dropout))
nprev = input_dim
for h in hidden_dims:
modules.append(nn.Linear(nprev, h))
modules.append(ACTIVATIONS[activation]())
if dropout > 0:
modules.append(nn.Dropout(p=dropout))
nprev = h
        # Final classifier layer followed by a softmax, so the model emits
        # class probabilities directly. (For training with nn.CrossEntropyLoss,
        # drop the softmax and work with the raw logits instead.)
        modules.append(nn.Linear(nprev, output_dim))
        modules.append(nn.Softmax(dim=-1))
self.mlp = nn.Sequential(*modules)
    def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
        # L2-normalize along the embedding dimension when norm_emb is set;
        # dim=-1 keeps this correct regardless of leading batch dimensions.
        return F.normalize(emb, dim=-1) if (emb is not None and self.norm_emb) else emb

    def featurize(self, premise: Tensor, hypothesis: Tensor) -> Tensor:
        # Sentence-pair interaction features: [p, h, p * h, |p - h|].
        return torch.cat(
            [premise, hypothesis, premise * hypothesis, torch.abs(premise - hypothesis)],
            dim=-1,
        )
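
# Minimal shape check (illustrative sketch, not part of the model):
#
#     core = MMNLICore(1024, [3072, 1536], dropout=0.1, activation="TANH",
#                      norm_emb=True)
#     p, h = torch.randn(2, 1024), torch.randn(2, 1024)
#     feats = core.featurize(core._norm(p), core._norm(h))
#     assert feats.shape == (2, 4096)         # 4 * embedding_dim
#     assert core.mlp(feats).shape == (2, 3)  # class probabilities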
# ---------------- HF MODEL WRAPPER ---------------- #
class MMNLIModel(PreTrainedModel):
config_class = MMNLIConfig
def __init__(self, config: MMNLIConfig):
super().__init__(config)
self.core = MMNLICore(
embedding_dim=config.embedding_dim,
hidden_dims=config.hidden_dims,
dropout=config.dropout,
activation=config.activation,
            norm_emb=config.norm_emb,
            output_dim=config.output_dim,
)
    def forward(self, premise: Tensor, hypothesis: Tensor):
        # premise/hypothesis: precomputed sentence embeddings of shape
        # (batch, embedding_dim); returns (batch, 3) class probabilities.
        premise = self.core._norm(premise)
        hypothesis = self.core._norm(hypothesis)
        features = self.core.featurize(premise, hypothesis)
        return self.core.mlp(features)
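
# ---------------- USAGE SKETCH ---------------- #
# A minimal end-to-end sketch, assuming premise/hypothesis embeddings come
# from an external sentence encoder (not part of this file); random tensors
# stand in for real embeddings here.
if __name__ == "__main__":
    config = MMNLIConfig()
    model = MMNLIModel(config).eval()  # eval() disables dropout
    premise = torch.randn(2, config.embedding_dim)
    hypothesis = torch.randn(2, config.embedding_dim)
    with torch.no_grad():
        probs = model(premise, hypothesis)  # (2, 3); each row sums to 1
    print(probs)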