
PyTorch Learning Notes (12): Model Containers

程序员文章站 2022-03-21 19:49:48

Containers

nn.Sequential — wraps multiple network layers and runs them in order
nn.ModuleList — wraps multiple network layers like a Python list
nn.ModuleDict — wraps multiple network layers like a Python dict

nn.Sequential is an nn.Module container that wraps a group of network layers and executes them in order.

Two features of Sequential

Ordered: the network layers are built and executed strictly in the given order
Built-in forward(): the provided forward iterates over the layers with a for loop, running the forward pass layer by layer
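The built-in forward described above can be reproduced by hand. A minimal sketch (layer sizes chosen arbitrarily for illustration): looping over a Sequential's submodules in order gives the same result as calling the container directly.

```python
import torch
import torch.nn as nn

# A Sequential container: layers run strictly in registration order.
seq = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

x = torch.randn(3, 8)
out_builtin = seq(x)

# The same loop that Sequential's built-in forward performs internally.
out_manual = x
for module in seq:
    out_manual = module(out_manual)

print(torch.equal(out_builtin, out_manual))  # True
```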

Container: ModuleList

nn.ModuleList is an nn.Module container that wraps a group of network layers so they can be called iteratively.
append(): add a network layer at the end of the ModuleList
extend(): concatenate another ModuleList
insert(): insert a network layer at a specified position in the ModuleList
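A minimal sketch of the three methods above (the layer types and sizes here are arbitrary, chosen only for illustration):

```python
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])

layers.append(nn.Linear(10, 2))            # append: add one layer at the end
layers.extend(nn.ModuleList([nn.ReLU()]))  # extend: concatenate another ModuleList
layers.insert(0, nn.Flatten())             # insert: place a layer at index 0

print(len(layers))               # 6
print(type(layers[0]).__name__)  # Flatten
```

All layers added this way are registered as submodules, so their parameters show up in `parameters()` and `state_dict()`, unlike layers stored in a plain Python list.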

Container: ModuleDict

nn.ModuleDict is an nn.Module container that wraps a group of network layers so they can be called by key.

Main methods

clear(): empty the ModuleDict
items(): return iterable key-value pairs
keys(): return the dict's keys
values(): return the dict's values
pop(key): return the module stored under key and remove it from the dict
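A minimal sketch of the methods above (the activation layers are arbitrary examples):

```python
import torch.nn as nn

acts = nn.ModuleDict({'relu': nn.ReLU(), 'prelu': nn.PReLU()})

print(sorted(acts.keys()))              # ['prelu', 'relu']
removed = acts.pop('prelu')             # pop: returns the module, removes the key
print(type(removed).__name__)           # PReLU
for name, module in acts.items():       # items: iterable (key, module) pairs
    print(name, type(module).__name__)  # relu ReLU
acts.clear()                            # clear: empty the ModuleDict
print(len(acts))                        # 0
```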

Summary

nn.Sequential — ordered: layers run strictly in sequence; commonly used to build blocks
nn.ModuleList — iterable: commonly used to build many repeated layers via a for loop
nn.ModuleDict — indexable: commonly used for selectable network layers

# -*- coding: utf-8 -*-

import torch
import torchvision
import torch.nn as nn
from collections import OrderedDict


# ============================ Sequential
class LeNetSequential(nn.Module):
    def __init__(self, classes):
        super(LeNetSequential, self).__init__()
        self.features = nn.Sequential( # wrap the conv and pooling layers with Sequential
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),)

        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes),)
    # forward pass
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x


class LeNetSequentialOrderDict(nn.Module):
    def __init__(self, classes):
        super(LeNetSequentialOrderDict, self).__init__()

        self.features = nn.Sequential(OrderedDict({
            'conv1': nn.Conv2d(3, 6, 5),
            'relu1': nn.ReLU(inplace=True),
            'pool1': nn.MaxPool2d(kernel_size=2, stride=2),

            'conv2': nn.Conv2d(6, 16, 5),
            'relu2': nn.ReLU(inplace=True),
            'pool2': nn.MaxPool2d(kernel_size=2, stride=2),
        }))

        self.classifier = nn.Sequential(OrderedDict({
            'fc1': nn.Linear(16*5*5, 120),
            'relu3': nn.ReLU(),

            'fc2': nn.Linear(120, 84),
            'relu4': nn.ReLU(inplace=True),

            'fc3': nn.Linear(84, classes),
        }))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x


# net = LeNetSequential(classes=2)
# net = LeNetSequentialOrderDict(classes=2)
#
# fake_img = torch.randn((4, 3, 32, 32), dtype=torch.float32)
#
# output = net(fake_img)
#
# print(net)
# print(output)


# ============================ ModuleList
# build 20 fully connected layers by iteration
# each fully connected layer has 10 neurons
class ModuleList(nn.Module):
    def __init__(self):
        super(ModuleList, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])

    def forward(self, x):
        for i, linear in enumerate(self.linears):
            x = linear(x)
        return x


# net = ModuleList()
#
# print(net)
#
# fake_data = torch.ones((10, 10))
#
# output = net(fake_data)
#
# print(output)


# ============================ ModuleDict

class ModuleDict(nn.Module):
    def __init__(self):
        super(ModuleDict, self).__init__()
        self.choices = nn.ModuleDict({
            'conv': nn.Conv2d(10, 10, 3),
            'pool': nn.MaxPool2d(3)
        })

        self.activations = nn.ModuleDict({
            'relu': nn.ReLU(),
            'prelu': nn.PReLU()
        })

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x


net = ModuleDict()

fake_img = torch.randn((4, 10, 32, 32))

output = net(fake_img, 'conv', 'relu')

print(output)



AlexNet

Won the 2012 ImageNet classification challenge with an accuracy more than 10 percentage points above the runner-up, opening a new era for convolutional neural networks.

Features:

1. ReLU: replaces saturating activation functions, alleviating vanishing gradients
2. LRN (Local Response Normalization): normalizes responses locally, alleviating vanishing gradients
3. Dropout: improves the robustness of the fully connected layers and the network's generalization ability
4. Data Augmentation: TenCrop, color jittering

alexnet = torchvision.models.AlexNet()