pytorch_pretrain_finetune
程序员文章站
2022-05-27 09:45:09
...
Introduction
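This is a small example of how to load a pretrained model (VGG16) from PyTorch's torchvision and modify it so it can be trained on the CIFAR10 dataset.
The same approach carries over well to other datasets, although the changes made to the network may need to be adjusted.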
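Across datasets, the main thing that changes is the input size of the new classifier's first `nn.Linear`, because it depends on the image resolution. The sketch below assumes the standard VGG16 layout: 3x3 convolutions with padding 1 keep the spatial size, and five 2x2 max-pools with stride 2 halve it (rounding down). Under that assumption, it computes the flattened size of the convolutional features. `vgg16_feature_size` is a hypothetical helper, not part of torchvision.

```python
def vgg16_feature_size(height: int, width: int, channels: int = 512, num_pools: int = 5) -> int:
    """Flattened size of VGG16's conv features for a height x width input.

    Assumes size-preserving 3x3/padding-1 convolutions and `num_pools`
    2x2/stride-2 max-pools, each of which halves the size (floor division).
    """
    for _ in range(num_pools):
        height //= 2
        width //= 2
    return channels * height * width


if __name__ == "__main__":
    print(vgg16_feature_size(32, 32))    # CIFAR10 resolution   -> 512
    print(vgg16_feature_size(224, 224))  # ImageNet resolution  -> 25088
```

For 32x32 CIFAR10 images this gives 512. That is why a 512-input `nn.Linear` works once the average-pooling layer no longer reshapes the features. For 96x96 images, the same head would need 4608 inputs instead.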
```python
# Imports
import torch
import torchvision
import torch.nn as nn  # All neural network modules: nn.Linear, nn.Conv2d, BatchNorm, loss functions
import torch.optim as optim  # All optimization algorithms: SGD, Adam, etc.
import torch.nn.functional as F  # All functions that don't have any parameters
from torch.utils.data import DataLoader  # Easier dataset management; creates mini-batches
import torchvision.datasets as datasets  # Standard datasets that can be imported conveniently
import torchvision.transforms as transforms  # Transformations we can apply to the dataset

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_classes = 10
learning_rate = 1e-3
batch_size = 1024
num_epochs = 5

# Load the pretrained model & modify it
# (torchvision >= 0.13 uses `weights=`; older versions use `pretrained=True`)
model = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
# print(model)  # inspect the layers before modifying them
```
Next, modify the network for CIFAR10:
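```python
# Simple Identity class that lets the input pass through unchanged
# (recent PyTorch versions also provide the built-in nn.Identity)
class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


# Freeze the pretrained weights to use VGG16 as a fixed feature extractor;
# only the new classifier created below will be trained.
# Remove these two lines to fine-tune the entire model, starting from the
# pretrained weights.
for param in model.parameters():
    param.requires_grad = False

# For 32x32 CIFAR10 images the conv features are 512x1x1. The stock
# AdaptiveAvgPool2d((7, 7)) would stretch them to 512x7x7 (25088 values),
# so it is replaced and the new classifier takes the 512 flattened features.
model.avgpool = Identity()
model.classifier = nn.Sequential(
    nn.Linear(512, 100), nn.ReLU(), nn.Linear(100, num_classes)
)

# Alternative (not used here): keep the original classifier and replace only
# layers 1-6, so classifier[0] (Linear(25088, 4096)) outputs 4096-d features.
# This requires keeping the original avgpool, since classifier[0] expects 25088 inputs.
# for i in range(1, 7):
#     model.classifier[i] = Identity()

model.to(device)
```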
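The section stops once the model is on the device. The hyperparameters and the `optim`/`DataLoader` imports suggest that a standard training loop comes next. The following is a minimal sketch of one such loop, not the author's code. It assumes cross-entropy loss and Adam, and it passes only parameters with `requires_grad=True` to the optimizer, so the frozen backbone stays unchanged. `train` is a hypothetical helper. A tiny model and random data stand in for VGG16 and CIFAR10 so the sketch runs without any downloads.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def train(model, loader, num_epochs, learning_rate, device):
    """Hypothetical helper: a plain supervised training loop."""
    criterion = nn.CrossEntropyLoss()
    # Only optimise parameters left trainable (e.g. the new classifier head)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=learning_rate)

    model.train()
    losses = []
    for epoch in range(num_epochs):
        running = 0.0
        for data, targets in loader:
            data, targets = data.to(device), targets.to(device)
            scores = model(data)
            loss = criterion(scores, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item() * data.size(0)
        losses.append(running / len(loader.dataset))  # mean loss per sample
        print(f"epoch {epoch}: loss {losses[-1]:.4f}")
    return losses


if __name__ == "__main__":
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Stand-ins: a frozen "backbone" plus a trainable head, and random data
    backbone = nn.Linear(8, 8)
    for p in backbone.parameters():
        p.requires_grad = False
    toy_model = nn.Sequential(backbone, nn.ReLU(), nn.Linear(8, 10)).to(device)

    dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 10, (64,)))
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    train(toy_model, loader, num_epochs=2, learning_rate=1e-3, device=device)
```

With the article's objects, this would be called as `train(model, train_loader, num_epochs, learning_rate, device)`. Here `train_loader` is a hypothetical `DataLoader` over `datasets.CIFAR10(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)`, with `batch_size=batch_size`. The `root` path is only a placeholder.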