File size: 1,004 Bytes
8c48c68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch.nn as nn

class LSTM(nn.Module):
    """Five chained ``nn.LSTM`` stages followed by a linear projection.

    Each stage is its own ``nn.LSTM`` with ``num_layers`` internal layers and
    ``batch_first=True``. The first stage consumes inputs of width
    ``embedding_dim``; the remaining four map ``hidden_dim`` to ``hidden_dim``.
    Besides its output sequence, every stage also hands its final
    ``(hidden, cell)`` state to the next stage as that stage's initial state.
    """

    def __init__(self, embedding_dim, hidden_dim, num_layers, output_dim):
        """Build the five recurrent stages and the output projection.

        Args:
            embedding_dim: Feature width of the input fed to the first stage.
            hidden_dim: Hidden width shared by all five stages.
            num_layers: Number of stacked layers inside each ``nn.LSTM``.
            output_dim: Feature width produced by the final linear layer.
        """
        super().__init__()

        def build_stage(input_size):
            # All stages share hidden width, depth, and batch-first layout;
            # only the input width differs.
            return nn.LSTM(
                input_size=input_size,
                hidden_size=hidden_dim,
                num_layers=num_layers,
                batch_first=True,
            )

        # Attribute names and creation order are kept stable so that
        # state_dict keys and seeded parameter initialisation do not change.
        self.lstm1 = build_stage(embedding_dim)
        self.lstm2 = build_stage(hidden_dim)
        self.lstm3 = build_stage(hidden_dim)
        self.lstm4 = build_stage(hidden_dim)
        self.lstm5 = build_stage(hidden_dim)
        self.o = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, embedding):
        """Run ``embedding`` through all five stages and project the result.

        Args:
            embedding: Input tensor for the first stage. With
                ``batch_first=True`` a batched input is laid out as
                ``(batch, seq_len, embedding_dim)``.

        Returns:
            The linear projection applied to the last stage's full output
            sequence (every time step, not only the final one).
        """
        stages = (self.lstm1, self.lstm2, self.lstm3, self.lstm4, self.lstm5)

        # ``None`` is nn.LSTM's default initial state, so the first stage
        # behaves exactly as if it were called with the input alone.
        sequence, state = embedding, None
        for stage in stages:
            sequence, state = stage(sequence, state)

        return self.o(sequence)