欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

【Pytorch框架学习】之分类应用中迁移学习网络修改总结(3)

程序员文章站 2022-05-18 10:03:50
...

【Pytorch框架学习】之迁移学习网络分类总结(3)

一、内容

在做图像分类应用时,常常会使用一些经典的预训练网络,比如ResNet、VGG、Inception、DenseNet、EfficientNet、ResNeXt等。但是通常都是在ImageNet上预训练的,所以不能直接使用,因此需要最一些层的修改,这里总结了几种,方法是一样的,这里以二分类为例。
【Pytorch框架学习】之分类应用中迁移学习网络修改总结(3)

二、代码

import torch.nn as nn
import torchvision.models as models

num_class = 2

# ResNet
model = models.resnet50()
model.fc = nn.Linear(512, num_class)
print(model)

# AlexNet
model = models.AlexNet()
model.classifier[6] = nn.Linear(4096, num_class)
print(model)

# vgg
model = models.vgg16()
model.classifier[6] = nn.Linear(4096, num_class)
print(model)

# SqueezeNet
model = models.squeezenet1_0()
model.classifier[1] = nn.Conv2d(512, num_class, kernel_size=(1, 1), stride=(1, 1))
print(model)

# DenseNet
model = models.densenet121()
model.classifier = nn.Linear(1024, num_class)
print(model)

# Inception
model = models.inception_v3()
model.AuxLogits.fc = nn.Linear(768, num_class)
model.fc = nn.Linear(2048, num_class)
print(model)
相关标签: Pytorch框架