from typing import Optional

import torch
import torch.nn as nn

from feedforward import FeedForward
from multi_head_attention import MultiHeadAttention


class TransformerBlock(nn.Module):
    """Attention + feed-forward block with residual connections and LayerNorm.

    Each sub-layer's output goes through dropout and is added back to its
    input (residual). The sum is then layer-normalized. Normalization happens
    *after* the residual addition (post-norm ordering).

    Args:
        d_model: Feature size of the block's input/output. It is passed to
            both sub-layers and used as the normalized size of both LayerNorms.
        n_heads: Number of attention heads, forwarded to ``MultiHeadAttention``.
        ff_dim: Hidden size of the feed-forward sub-layer, forwarded to
            ``FeedForward``.
        dropout: Dropout probability applied to each sub-layer's output before
            the residual addition. Defaults to 0.1, the value this block
            previously hard-coded.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        ff_dim: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # Registration order/names are kept stable so state_dict keys match
        # checkpoints saved by earlier versions of this block.
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.ffn = FeedForward(d_model, ff_dim)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the attention and feed-forward sub-layers over ``x``.

        Args:
            x: Input tensor whose last dimension is ``d_model`` (required by
                the LayerNorms). Presumably (batch, seq_len, d_model); this
                depends on ``MultiHeadAttention``, so TODO confirm.
            mask: Optional mask forwarded unchanged to ``MultiHeadAttention``.
                Its expected shape and masking convention are defined by that
                module, not here.

        Returns:
            Tensor with the same shape as ``x``. Both residual additions
            require each sub-layer to preserve the input shape.
        """
        # x is passed as the first three arguments of the attention module,
        # presumably (query, key, value), i.e. self-attention. NOTE(review):
        # confirm against MultiHeadAttention's signature.
        attn_out = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_out))

        ff_out = self.ffn(x)
        x = self.norm2(x + self.dropout2(ff_out))
        return x