import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    """A single post-norm transformer block: multi-head self-attention followed by a feed-forward network."""

    def __init__(self, d_model, n_heads, ff_dim):
        super().__init__()
        # Multi-head self-attention; batch_first=True expects inputs of shape (batch, seq_len, d_model).
        self.attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        # Position-wise feed-forward network.
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Self-attention with a residual connection, then layer norm (post-norm arrangement).
        attn_output, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_output)
        # Feed-forward with a residual connection, then layer norm.
        x = self.norm2(x + self.ff(x))
        return x


class TransformerModel(nn.Module):
    """Token embedding plus learned positional embedding, a stack of transformer blocks, and a linear output head."""

    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned absolute positional embeddings, one vector per position up to max_len.
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, d_model))
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, ff_dim=4 * d_model)
            for _ in range(n_layers)
        ])
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len) of token indices; add positional embeddings truncated to the sequence length.
        x = self.embedding(x) + self.pos_embedding[:, :x.size(1), :]
        for block in self.transformer_blocks:
            x = block(x)
        # Project back to vocabulary logits of shape (batch, seq_len, vocab_size).
        return self.output(x)
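
# A minimal usage sketch (not part of the original listing): the hyperparameters
# below (vocab_size=10000, d_model=256, n_heads=8, n_layers=4, max_len=512) are
# illustrative assumptions chosen only to show the expected input and output shapes.
if __name__ == "__main__":
    model = TransformerModel(vocab_size=10000, d_model=256, n_heads=8, n_layers=4, max_len=512)
    tokens = torch.randint(0, 10000, (2, 128))  # batch of 2 sequences, 128 token ids each
    logits = model(tokens)
    print(logits.shape)  # torch.Size([2, 128, 10000])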