File size: 3,062 Bytes
6ffd2bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel

# ---------------- CONFIG ---------------- #
class MMNLIConfig(PretrainedConfig):
    """Hugging Face configuration for the MMNLI pair classifier.

    Args:
        embedding_dim: Size of the last dimension of each premise /
            hypothesis embedding.
        hidden_dims: Widths of the hidden layers, in order. ``None`` selects
            the default ``[3072, 1536]``.
        dropout: Dropout probability used by the classifier.
        activation: Name of the hidden-layer activation (a key of
            ``ACTIVATIONS``, e.g. ``"TANH"``).
        norm_emb: Whether embeddings are L2-normalized before featurization.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "mmnli"

    def __init__(
        self,
        embedding_dim: int = 1024,
        hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.1,
        activation: str = "TANH",
        norm_emb: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # A fresh list per call, so instances never share a mutable default.
        if hidden_dims is None:
            hidden_dims = [3072, 1536]

        self.embedding_dim = embedding_dim
        self.hidden_dims = hidden_dims
        self.dropout = dropout
        self.activation = activation
        self.norm_emb = norm_emb
        # Always 3 classes (entailment, contradiction, neutral); set after
        # super().__init__ so it is fixed regardless of incoming kwargs.
        self.output_dim = 3


# ---------------- CORE MODEL ---------------- #
# Activation name (as stored in the config) -> nn.Module class that is
# instantiated after every hidden Linear layer of the classifier MLP.
ACTIVATIONS = dict(TANH=nn.Tanh, RELU=nn.ReLU)


class MMNLICore(nn.Module):
    """MLP classifier over a (premise, hypothesis) pair of embeddings.

    Embeddings are optionally L2-normalized (``_norm``), combined into the
    feature vector ``[p, h, p * h, |p - h|]`` (``featurize``) and fed to
    ``self.mlp``. The MLP ends in ``Softmax(dim=-1)`` over 3 outputs, so it
    returns class probabilities, not logits -- do not feed its output to a
    loss that applies its own softmax (e.g. ``nn.CrossEntropyLoss``).

    Args:
        embedding_dim: Size of the last dimension of each embedding.
        hidden_dims: Widths of the hidden ``Linear`` layers, in order.
        dropout: Dropout probability; ``Dropout`` layers are only inserted
            when ``dropout > 0``.
        activation: Key of ``ACTIVATIONS`` selecting the hidden activation.
        norm_emb: If true, ``_norm`` L2-normalizes its input.

    Raises:
        ValueError: If ``activation`` is not a key of ``ACTIVATIONS``.
    """

    def __init__(
        self,
        embedding_dim: int,
        hidden_dims: List[int],
        dropout: float,
        activation: str,
        norm_emb: bool,
    ):
        super().__init__()
        self.norm_emb = norm_emb

        if activation not in ACTIVATIONS:
            raise ValueError(f"Unrecognized activation: {activation}")

        # Input: concatenation of [p, h, p*h, |p-h|] => 4 * embedding_dim
        input_dim = embedding_dim * 4

        # NOTE: the module order below determines the ``mlp.<i>`` keys of the
        # state dict; keep it stable so saved checkpoints still load.
        modules: List[nn.Module] = []
        if dropout > 0:
            modules.append(nn.Dropout(p=dropout))

        nprev = input_dim
        for h in hidden_dims:
            modules.append(nn.Linear(nprev, h))
            modules.append(ACTIVATIONS[activation]())
            if dropout > 0:
                modules.append(nn.Dropout(p=dropout))
            nprev = h

        # Final classifier layer: 3-way softmax
        modules.append(nn.Linear(nprev, 3))
        modules.append(nn.Softmax(dim=-1))

        self.mlp = nn.Sequential(*modules)

    def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
        """L2-normalize ``emb`` along its last dimension if ``norm_emb`` is set.

        ``None`` is returned unchanged, and so is any input when
        ``norm_emb`` is false.
        """
        if emb is None or not self.norm_emb:
            return emb
        # Normalize over the feature axis (dim=-1), the same axis ``featurize``
        # concatenates on. F.normalize defaults to dim=1, which is the feature
        # axis only for 2-D (batch, dim) input: it fails on 1-D input and
        # normalizes across the wrong axis when there are extra leading dims.
        return F.normalize(emb, dim=-1)

    def featurize(self, premise: Tensor, hypothesis: Tensor) -> Tensor:
        """Build the pair features ``[p, h, p * h, |p - h|]``.

        The four parts are concatenated along the last dimension, so the
        result's last dimension is 4x that of the inputs. ``premise`` and
        ``hypothesis`` are expected to have the same shape.
        """
        return torch.cat(
            [premise, hypothesis, premise * hypothesis, torch.abs(premise - hypothesis)],
            dim=-1,
        )


# ---------------- HF MODEL WRAPPER ---------------- #
class MMNLIModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``MMNLICore``.

    All hyper-parameters are read from an ``MMNLIConfig``. The classifier is
    stored as ``self.core``, so its parameters appear under the ``core.``
    prefix in the state dict.
    """

    config_class = MMNLIConfig

    def __init__(self, config: MMNLIConfig):
        super().__init__(config)
        core_kwargs = dict(
            embedding_dim=config.embedding_dim,
            hidden_dims=config.hidden_dims,
            dropout=config.dropout,
            activation=config.activation,
            norm_emb=config.norm_emb,
        )
        self.core = MMNLICore(**core_kwargs)

    def forward(self, premise: Tensor, hypothesis: Tensor):
        """Score premise/hypothesis pairs with the core MLP.

        Both embeddings go through ``core._norm`` (a no-op unless
        ``norm_emb`` is set), then ``core.featurize``, then ``core.mlp``.
        The MLP ends in a softmax, so the result holds class probabilities
        over 3 classes (presumably one row per pair for batched input --
        confirm against callers).
        """
        core = self.core
        features = core.featurize(core._norm(premise), core._norm(hypothesis))
        return core.mlp(features)