欢迎您访问程序员文章站！本站旨在为大家提供和分享程序员计算机编程知识。
您现在的位置是: 首页

使用pytorch的LSTM实现MNIST数据集分类任务

程序员文章站 2024-03-24 23:40:16
...

 使用pytorch的LSTM实现MNIST数据集分类任务

"""
__author__:shuangrui Guo
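__description__: Classify MNIST digits with a 2-layer LSTM (hidden size 10) that reads each 28x28 image as a sequence of 28 rows of 28 pixels.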
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader


class Rnn_LSTM(nn.Module):
    """Sequence classifier: a stacked LSTM followed by a linear output layer.

    The LSTM is built with ``batch_first=True``, so ``forward`` expects
    input of shape (batch, seq_len, input_dim). Only the final hidden
    state of the top LSTM layer is passed to the classifier.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the LSTM hidden state.
        n_layers: Number of stacked LSTM layers.
        n_classes: Number of output logits produced per sample.
    """

    def __init__(self, input_dim, hidden_dim, n_layers, n_classes):
        super().__init__()
        # Stored for reference; forward() does not read them.
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        # Registration order (lstm first, classifier second) fixes the
        # parameter order seen by the optimizer.
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True,
        )
        self.classifier = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, n_classes)."""
        # Per-step outputs and the cell state are not needed.
        _, (hidden, _) = self.lstm(x)
        # hidden is stacked per layer; index -1 selects the top layer.
        top_layer_state = hidden[-1]
        return self.classifier(top_layer_state)

# Training / evaluation setup: data pipeline, model, loss function and optimiser.

# ToTensor converts each MNIST image to a float tensor; Normalize(mean=0.5,
# std=0.5) then shifts/scales the single channel (values in [0, 1] map to [-1, 1]).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Training split: shuffled mini-batches of 128 images.
train_set = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
train_loader = DataLoader(dataset=train_set, batch_size=128, shuffle=True)

# Test split: fixed order, mini-batches of 100 images.
test_set = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)
test_loader = DataLoader(dataset=test_set, batch_size=100, shuffle=False)

# 28 features per time step, hidden size 10, 2 stacked LSTM layers, 10 classes.
net = Rnn_LSTM(input_dim=28, hidden_dim=10, n_layers=2, n_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=net.parameters(), lr=0.1, momentum=0.9)


# Training
def train(epoch):
    """Run one optimisation pass of the module-level ``net`` over ``train_loader``.

    Uses the module-level ``criterion`` and ``optimizer``. Prints an
    ``epoch:<n>`` header, then after every batch prints:
    - the batch index,
    - the number of batches,
    - the running mean batch loss,
    - the running accuracy.

    Args:
        epoch: Epoch index; only used in the printed header.
    """
    print(f'epoch:{epoch}')
    net.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    for step, (images, labels) in enumerate(train_loader):
        images = images.to('cpu')
        labels = labels.to('cpu')

        optimizer.zero_grad()
        # Drop dim 1 (the single MNIST channel) so the LSTM sees
        # (batch, rows, columns).
        logits = net(images.squeeze(1))
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        predicted = logits.max(dim=1).indices
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        mean_loss = running_loss / (step + 1)
        accuracy = 100 * n_correct / n_seen
        print(
            step,
            len(train_loader),
            'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (mean_loss, accuracy, n_correct, n_seen),
        )
def test(epoch):
    """Evaluate the module-level ``net`` on ``test_loader``, printing running stats.

    Uses the module-level ``criterion``. After every batch it prints:
    - the batch index,
    - the number of test batches,
    - the running mean batch loss,
    - the running accuracy.

    Args:
        epoch: Index of the current epoch. Not used inside the function; it
            is kept so the call mirrors ``train(epoch)``.

    Returns:
        Accuracy in percent (0-100) over all evaluated samples, or 0.0 when
        ``test_loader`` yields no samples.
    """
    # Switch the module to evaluation mode. If the network contains
    # BatchNorm or Dropout layers, eval() is required at inference time so
    # those layers use their inference behaviour and predictions are
    # consistent with the test data.
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to('cpu'), targets.to('cpu')
            # Drop dim 1 (the single MNIST channel) so the LSTM sees
            # (batch, rows, columns).
            outputs = net(torch.squeeze(inputs, 1))
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            # The batch count must come from test_loader, the loader being
            # iterated here, not from train_loader.
            print(
                batch_idx,
                len(test_loader),
                'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss / (batch_idx + 1), 100 * correct / total, correct, total),
            )
    return 100 * correct / total if total else 0.0


# Drive 200 epochs; each epoch runs a full training pass, then an evaluation pass.
for epoch in range(200):
    for run_phase in (train, test):
        run_phase(epoch)