LSTM, GRU
程序员文章站
2024-03-24 23:57:46
...
《越狱》 Prison Break 1-4季 片头
class RegLSTM(nn.Module):
    """Stacked LSTM with a linear regression head applied at every time step.

    Args:
        inp_dim: Number of input features per time step.
        out_dim: Number of regression outputs per time step.
        mid_dim: Hidden size of the LSTM.
        mid_layers: Number of stacked LSTM layers.

    The LSTM uses PyTorch's default ``batch_first=False``, so batched input
    is laid out as ``(seq, batch, inp_dim)``; unbatched ``(seq, inp_dim)``
    input is accepted as well.
    """

    def __init__(self, inp_dim, out_dim, mid_dim, mid_layers):
        super().__init__()
        self.rnn = nn.LSTM(inp_dim, mid_dim, mid_layers)
        # Kept as a Sequential so parameter names in the state_dict
        # ("reg.0.weight", "reg.0.bias") remain unchanged.
        self.reg = nn.Sequential(nn.Linear(mid_dim, out_dim),)

    def forward(self, x):
        """Return per-time-step predictions for ``x``.

        The output has the input's leading dimensions with the feature
        dimension replaced by ``out_dim``:
        - ``(seq, batch, out_dim)`` for batched input;
        - ``(seq, out_dim)`` for unbatched input.
        """
        # The final hidden and cell states are not needed for regression.
        features, _ = self.rnn(x)
        # nn.Linear acts on the last dimension and broadcasts over the
        # leading ones. No flatten/reshape round-trip is needed, and both
        # 3-D (batched) and 2-D (unbatched) LSTM output work.
        return self.reg(features)
class RegGRU(nn.Module):
def __init__(self, inp_dim, out_