PoseCrafts-API / model.py
Supeem's picture
test11
8c48c68
import torch.nn as nn
class LSTM(nn.Module):
def __init__(self, embedding_dim, hidden_dim, num_layers, output_dim):
super(LSTM, self).__init__()
self.lstm1 = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
self.lstm2 = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
self.lstm3 = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
self.lstm4 = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
self.lstm5 = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
self.o = nn.Linear(hidden_dim, output_dim)
def forward(self, embedding):
o_n1, (h_n1, c_n1) = self.lstm1(embedding)
o_n2, (h_n2, c_n2) = self.lstm2(o_n1, (h_n1, c_n1))
o_n3, (h_n3, c_n3) = self.lstm3(o_n2, (h_n2, c_n2))
o_n4, (h_n4, c_n4) = self.lstm4(o_n3, (h_n3, c_n3))
o_n5, (h_n5, c_n5) = self.lstm5(o_n4, (h_n4, c_n4))
output = self.o(o_n5)
return output