
pytorch_pretrain_finetune

程序员文章站 2022-05-27 09:45:09

Introduction

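This is a small example of how to load a pretrained model (VGG16) from PyTorch and modify it for training on the CIFAR10 dataset.
The same approach generalizes well to other datasets, but the modifications made to the network may need to change.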

# Imports
import torch
import torchvision
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torch.nn.functional as F  # All functions that don't have any parameters
from torch.utils.data import (
    DataLoader,
)  # Gives easier dataset management and creates mini batches
import torchvision.datasets as datasets  # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms  # Transformations we can perform on our dataset

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_classes = 10
learning_rate = 1e-3
batch_size = 1024
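num_epochs = 5

# Load the pretrained model; it is modified further down.
# Note: newer torchvision releases deprecate pretrained=True in favor of the weights= argument.
model = torchvision.models.vgg16(pretrained=True)
# print(model)  # uncomment to inspect the layer structure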

Adjusting the network

# Simple Identity class that lets the input pass through unchanged
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

# To finetune only the new classifier head, freeze the pretrained weights
# by setting requires_grad = False. Remove these two lines if you want to
# load the pretrained weights but still train the entire model.
for param in model.parameters():
    param.requires_grad = False

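# With 32x32 CIFAR10 inputs the convolutional features come out as 512x1x1,
# so bypass the adaptive average pool and feed the 512 flattened values to a new head.
model.avgpool = Identity()
model.classifier = nn.Sequential(
    nn.Linear(512, 100), nn.ReLU(), nn.Linear(100, num_classes)
)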
# for i in range(1, 7):
#     model.classifier[i] = Identity()
model.to(device)