欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

手把手教你实现PyTorch的MNIST数据集

程序员文章站 2022-06-19 08:08:16
概述mnist 包含 0~9 的手写数字, 共有 60000 个训练集和 10000 个测试集. 数据的格式为单通道 28*28 的灰度图.获取数据def get_data(): """获取数据...

概述

mnist 包含 0~9 的手写数字, 共有 60000 个训练集和 10000 个测试集. 数据的格式为单通道 28*28 的灰度图.

手把手教你实现PyTorch的MNIST数据集

获取数据

def get_data():
    """获取数据"""

    # 获取测试集
    train = torchvision.datasets.mnist(root="./data", train=true, download=true,
                                       transform=torchvision.transforms.compose([
                                           torchvision.transforms.totensor(),  # 转换成张量
                                           torchvision.transforms.normalize((0.1307,), (0.3081,))  # 标准化
                                       ]))
    train_loader = dataloader(train, batch_size=batch_size)  # 分割测试集

    # 获取测试集
    test = torchvision.datasets.mnist(root="./data", train=false, download=true,
                                      transform=torchvision.transforms.compose([
                                          torchvision.transforms.totensor(),  # 转换成张量
                                          torchvision.transforms.normalize((0.1307,), (0.3081,))  # 标准化
                                      ]))
    test_loader = dataloader(test, batch_size=batch_size)  # 分割训练

    # 返回分割好的训练集和测试集
    return train_loader, test_loader

网络模型

class model(torch.nn.module):
    def __init__(self):
        super(model, self).__init__()

        # 卷积层
        self.conv1 = torch.nn.conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
        self.conv2 = torch.nn.conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))

        # dropout层
        self.dropout1 = torch.nn.dropout(0.25)
        self.dropout2 = torch.nn.dropout(0.5)

        # 全连接层
        self.fc1 = torch.nn.linear(9216, 128)
        self.fc2 = torch.nn.linear(128, 10)

    def forward(self, x):
        """前向传播"""
        
        # [b, 1, 28, 28] => [b, 32, 26, 26]
        out = self.conv1(x)
        out = f.relu(out)
        
        # [b, 32, 26, 26] => [b, 64, 24, 24]
        out = self.conv2(out)
        out = f.relu(out)

        # [b, 64, 24, 24] => [b, 64, 12, 12]
        out = f.max_pool2d(out, 2)
        out = self.dropout1(out)
        
        # [b, 64, 12, 12] => [b, 64 * 12 * 12] => [b, 9216]
        out = torch.flatten(out, 1)
        
        # [b, 9216] => [b, 128]
        out = self.fc1(out)
        out = f.relu(out)

        # [b, 128] => [b, 10]
        out = self.dropout2(out)
        out = self.fc2(out)

        output = f.log_softmax(out, dim=1)

        return output

train 函数

def train(model, epoch, train_loader):
    """训练"""

    # 训练模式
    model.train()

    # 迭代
    for step, (x, y) in enumerate(train_loader):
        # 加速
        if use_cuda:
            model = model.cuda()
            x, y = x.cuda(), y.cuda()

        # 梯度清零
        optimizer.zero_grad()

        output = model(x)

        # 计算损失
        loss = f.nll_loss(output, y)

        # 反向传播
        loss.backward()

        # 更新梯度
        optimizer.step()

        # 打印损失
        if step % 50 == 0:
            print('epoch: {}, step {}, loss: {}'.format(epoch, step, loss))

test 函数

def test(model, test_loader):
    """测试"""
    
    # 测试模式
    model.eval()

    # 存放正确个数
    correct = 0

    with torch.no_grad():
        for x, y in test_loader:

            # 加速
            if use_cuda:
                model = model.cuda()
                x, y = x.cuda(), y.cuda()

            # 获取结果
            output = model(x)

            # 预测结果
            pred = output.argmax(dim=1, keepdim=true)

            # 计算准确个数
            correct += pred.eq(y.view_as(pred)).sum().item()

    # 计算准确率
    accuracy = correct / len(test_loader.dataset) * 100

    # 输出准确
    print("test accuracy: {}%".format(accuracy))

main 函数

def main():
    # 获取数据
    train_loader, test_loader = get_data()
    
    # 迭代
    for epoch in range(iteration_num):
        print("\n================ epoch: {} ================".format(epoch))
        train(network, epoch, train_loader)
        test(network, test_loader)

完整代码:

import torch
import torchvision
import torch.nn.functional as f
from torch.utils.data import dataloader
class model(torch.nn.module):
    def __init__(self):
        super(model, self).__init__()

        # 卷积层
        self.conv1 = torch.nn.conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
        self.conv2 = torch.nn.conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))

        # dropout层
        self.dropout1 = torch.nn.dropout(0.25)
        self.dropout2 = torch.nn.dropout(0.5)

        # 全连接层
        self.fc1 = torch.nn.linear(9216, 128)
        self.fc2 = torch.nn.linear(128, 10)

    def forward(self, x):
        """前向传播"""
        
        # [b, 1, 28, 28] => [b, 32, 26, 26]
        out = self.conv1(x)
        out = f.relu(out)
        
        # [b, 32, 26, 26] => [b, 64, 24, 24]
        out = self.conv2(out)
        out = f.relu(out)

        # [b, 64, 24, 24] => [b, 64, 12, 12]
        out = f.max_pool2d(out, 2)
        out = self.dropout1(out)
        
        # [b, 64, 12, 12] => [b, 64 * 12 * 12] => [b, 9216]
        out = torch.flatten(out, 1)
        
        # [b, 9216] => [b, 128]
        out = self.fc1(out)
        out = f.relu(out)

        # [b, 128] => [b, 10]
        out = self.dropout2(out)
        out = self.fc2(out)

        output = f.log_softmax(out, dim=1)

        return output


# 定义超参数
batch_size = 64  # 一次训练的样本数目
learning_rate = 0.0001  # 学习率
iteration_num = 5  # 迭代次数
network = model()  # 实例化网络
print(network)  # 调试输出网络结构
optimizer = torch.optim.adam(network.parameters(), lr=learning_rate)  # 优化器

# gpu 加速
use_cuda = torch.cuda.is_available()
print("是否使用 gpu 加速:", use_cuda)


def get_data():
    """获取数据"""

    # 获取测试集
    train = torchvision.datasets.mnist(root="./data", train=true, download=true,
                                       transform=torchvision.transforms.compose([
                                           torchvision.transforms.totensor(),  # 转换成张量
                                           torchvision.transforms.normalize((0.1307,), (0.3081,))  # 标准化
                                       ]))
    train_loader = dataloader(train, batch_size=batch_size)  # 分割测试集

    # 获取测试集
    test = torchvision.datasets.mnist(root="./data", train=false, download=true,
                                      transform=torchvision.transforms.compose([
                                          torchvision.transforms.totensor(),  # 转换成张量
                                          torchvision.transforms.normalize((0.1307,), (0.3081,))  # 标准化
                                      ]))
    test_loader = dataloader(test, batch_size=batch_size)  # 分割训练

    # 返回分割好的训练集和测试集
    return train_loader, test_loader


def train(model, epoch, train_loader):
    """训练"""

    # 训练模式
    model.train()

    # 迭代
    for step, (x, y) in enumerate(train_loader):
        # 加速
        if use_cuda:
            model = model.cuda()
            x, y = x.cuda(), y.cuda()

        # 梯度清零
        optimizer.zero_grad()

        output = model(x)

        # 计算损失
        loss = f.nll_loss(output, y)

        # 反向传播
        loss.backward()

        # 更新梯度
        optimizer.step()

        # 打印损失
        if step % 50 == 0:
            print('epoch: {}, step {}, loss: {}'.format(epoch, step, loss))


def test(model, test_loader):
    """测试"""

    # 测试模式
    model.eval()

    # 存放正确个数
    correct = 0

    with torch.no_grad():
        for x, y in test_loader:

            # 加速
            if use_cuda:
                model = model.cuda()
                x, y = x.cuda(), y.cuda()

            # 获取结果
            output = model(x)

            # 预测结果
            pred = output.argmax(dim=1, keepdim=true)

            # 计算准确个数
            correct += pred.eq(y.view_as(pred)).sum().item()

    # 计算准确率
    accuracy = correct / len(test_loader.dataset) * 100

    # 输出准确
    print("test accuracy: {}%".format(accuracy))


def main():
    # 获取数据
    train_loader, test_loader = get_data()

    # 迭代
    for epoch in range(iteration_num):
        print("\n================ epoch: {} ================".format(epoch))
        train(network, epoch, train_loader)
        test(network, test_loader)

if __name__ == "__main__":
    main()

输出结果:

model(
  (conv1): conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): dropout(p=0.25, inplace=false)
  (dropout2): dropout(p=0.5, inplace=false)
  (fc1): linear(in_features=9216, out_features=128, bias=true)
  (fc2): linear(in_features=128, out_features=10, bias=true)
)
是否使用 gpu 加速: true

================ epoch: 0 ================
epoch: 0, step 0, loss: 2.3131277561187744
epoch: 0, step 50, loss: 1.0419045686721802
epoch: 0, step 100, loss: 0.6259541511535645
epoch: 0, step 150, loss: 0.7194482684135437
epoch: 0, step 200, loss: 0.4020516574382782
epoch: 0, step 250, loss: 0.6890509128570557
epoch: 0, step 300, loss: 0.28660136461257935
epoch: 0, step 350, loss: 0.3277580738067627
epoch: 0, step 400, loss: 0.2750288248062134
epoch: 0, step 450, loss: 0.28428223729133606
epoch: 0, step 500, loss: 0.3514065444469452
epoch: 0, step 550, loss: 0.23386947810649872
epoch: 0, step 600, loss: 0.25338059663772583
epoch: 0, step 650, loss: 0.1743898093700409
epoch: 0, step 700, loss: 0.35752204060554504
epoch: 0, step 750, loss: 0.17575909197330475
epoch: 0, step 800, loss: 0.20604261755943298
epoch: 0, step 850, loss: 0.17389622330665588
epoch: 0, step 900, loss: 0.3188241124153137
test accuracy: 96.56%

================ epoch: 1 ================
epoch: 1, step 0, loss: 0.23558208346366882
epoch: 1, step 50, loss: 0.13511177897453308
epoch: 1, step 100, loss: 0.18823786079883575
epoch: 1, step 150, loss: 0.2644936144351959
epoch: 1, step 200, loss: 0.145077645778656
epoch: 1, step 250, loss: 0.30574971437454224
epoch: 1, step 300, loss: 0.2386859953403473
epoch: 1, step 350, loss: 0.08346735686063766
epoch: 1, step 400, loss: 0.10480977594852448
epoch: 1, step 450, loss: 0.07280707359313965
epoch: 1, step 500, loss: 0.20928426086902618
epoch: 1, step 550, loss: 0.20455852150917053
epoch: 1, step 600, loss: 0.10085935145616531
epoch: 1, step 650, loss: 0.13476189970970154
epoch: 1, step 700, loss: 0.19087043404579163
epoch: 1, step 750, loss: 0.0981522724032402
epoch: 1, step 800, loss: 0.1961515098810196
epoch: 1, step 850, loss: 0.041140712797641754
epoch: 1, step 900, loss: 0.250461220741272
test accuracy: 98.03%

================ epoch: 2 ================
epoch: 2, step 0, loss: 0.09572553634643555
epoch: 2, step 50, loss: 0.10370486229658127
epoch: 2, step 100, loss: 0.17737184464931488
epoch: 2, step 150, loss: 0.1570713371038437
epoch: 2, step 200, loss: 0.07462178170681
epoch: 2, step 250, loss: 0.18744900822639465
epoch: 2, step 300, loss: 0.09910508990287781
epoch: 2, step 350, loss: 0.08929706364870071
epoch: 2, step 400, loss: 0.07703761011362076
epoch: 2, step 450, loss: 0.10133732110261917
epoch: 2, step 500, loss: 0.1314031481742859
epoch: 2, step 550, loss: 0.10394387692213058
epoch: 2, step 600, loss: 0.11612939089536667
epoch: 2, step 650, loss: 0.17494803667068481
epoch: 2, step 700, loss: 0.11065669357776642
epoch: 2, step 750, loss: 0.061209067702293396
epoch: 2, step 800, loss: 0.14715790748596191
epoch: 2, step 850, loss: 0.03930797800421715
epoch: 2, step 900, loss: 0.18030673265457153
test accuracy: 98.46000000000001%

================ epoch: 3 ================
epoch: 3, step 0, loss: 0.09266342222690582
epoch: 3, step 50, loss: 0.0414913073182106
epoch: 3, step 100, loss: 0.2152961939573288
epoch: 3, step 150, loss: 0.12287424504756927
epoch: 3, step 200, loss: 0.13468700647354126
epoch: 3, step 250, loss: 0.11967387050390244
epoch: 3, step 300, loss: 0.11301510035991669
epoch: 3, step 350, loss: 0.037447575479745865
epoch: 3, step 400, loss: 0.04699449613690376
epoch: 3, step 450, loss: 0.05472381412982941
epoch: 3, step 500, loss: 0.09839300811290741
epoch: 3, step 550, loss: 0.07964356243610382
epoch: 3, step 600, loss: 0.08182843774557114
epoch: 3, step 650, loss: 0.05514759197831154
epoch: 3, step 700, loss: 0.13785190880298615
epoch: 3, step 750, loss: 0.062480345368385315
epoch: 3, step 800, loss: 0.120387002825737
epoch: 3, step 850, loss: 0.04458726942539215
epoch: 3, step 900, loss: 0.17119190096855164
test accuracy: 98.55000000000001%

================ epoch: 4 ================
epoch: 4, step 0, loss: 0.08094145357608795
epoch: 4, step 50, loss: 0.05615215748548508
epoch: 4, step 100, loss: 0.07766406238079071
epoch: 4, step 150, loss: 0.07915271818637848
epoch: 4, step 200, loss: 0.1301635503768921
epoch: 4, step 250, loss: 0.12118984013795853
epoch: 4, step 300, loss: 0.073218435049057
epoch: 4, step 350, loss: 0.04517696052789688
epoch: 4, step 400, loss: 0.08493026345968246
epoch: 4, step 450, loss: 0.03904269263148308
epoch: 4, step 500, loss: 0.09386837482452393
epoch: 4, step 550, loss: 0.12583576142787933
epoch: 4, step 600, loss: 0.09053893387317657
epoch: 4, step 650, loss: 0.06912104040384293
epoch: 4, step 700, loss: 0.1502612829208374
epoch: 4, step 750, loss: 0.07162325084209442
epoch: 4, step 800, loss: 0.10512275993824005
epoch: 4, step 850, loss: 0.028180215507745743
epoch: 4, step 900, loss: 0.08492615073919296
test accuracy: 98.69%

到此这篇关于手把手教你实现pytorch的mnist数据集的文章就介绍到这了,更多相关pytorch mnist数据集内容请搜索以前的文章或继续浏览下面的相关文章希望大家以后多多支持!