Pytorch LSTM
程序员文章站
2024-03-24 23:06:28
...
#定义模型
class rnn_classify(nn.Module):
def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
super(rnn_classify, self).__init__()
self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers) #使用两层lstm
self.classifier = nn.Linear(hidden_feature, num_class) #将最后一个的rnn使用全连接的到最后的输出结果
def forward(self, x):
#x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
x = x.squeeze() #去掉(batch,1, 28,28)中的1,变成(batch,28, 28)
x = x.permute(2, 0, 1) #将最后一维放到第一维,变成(28, batch,28)
out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
out = out[-1,:,:] #取序列中的最后一个,大小是(batch, hidden_feature)
out = self.classifier(out) #得到分类结果
return out
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
#author: www
squeeze()
上一篇: 数学建模笔记-斜抛运动建模