欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch 预训练模型下载和加载

程序员文章站 2022-06-13 16:02:28
...

PyTorch 加载和下载预训练模型可参考:pytorch预训练模型的下载地址以及解决下载速度慢的方法

- 下载地址

常用预训练模型在这里面:https://github.com/pytorch/vision/tree/master/torchvision/models

但是上述网址只有常见的 backbone (vgg, resnet, densenet, alexnet),在 GitHub 上,还找到了一个项目,提供 NASNet, ResNeXt, ResNet, InceptionV4, InceptionResnetV2, Xception, DPN 等预训练模型的下载:https://github.com/Cadene/pretrained-models.pytorch

具体下载位置是:https://data.lip6.fr/cadene/pretrainedmodels/

- 加载预训练模型

一般使用的是使用 model.load_state_dict() 函数。

model_urls = {  'resnet50': '/home/huihua/NewDisk1/pretrain_parameter/resnet50-19c8e357.pth',}
def resnet50(pretrained=False, **kwargs):
	model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
	if pretrained:
		model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
	return model

此时它会到指定的网站下载预训练模型到本地缓存中,本地缓存的位置(Linux系统)一般在:

.cache/torch/checkpoints

PyTorch 在加载模型时候首先检查本地缓存是否已经存在预训练模型,所以在本地缓存汇总预先放入已经下载的模型可快速加载模型。

如果需要更改预训练模型的位置,可以在文件开头加入:

os.environ['TORCH_HOME']= './pretrained_models/'

pretrained_models 文件夹下新建一个 checkpoints 文件夹并把预训练模型放入即可。

- 参考

  1. pytorch预训练模型下载URL及加载调用方法
  2. pytorch学习笔记之加载预训练模型