Pytorch 预训练模型下载和加载
程序员文章站
2022-06-13 16:02:28
...
PyTorch 加载和下载预训练模型可参考:pytorch预训练模型的下载地址以及解决下载速度慢的方法
- 下载地址
常用预训练模型在这里面:https://github.com/pytorch/vision/tree/master/torchvision/models
但是上述网址只有常见的 backbone (vgg, resnet, densenet, alexnet),在 GitHub 上,还找到了一个项目,提供 NASNet, ResNeXt, ResNet, InceptionV4, InceptionResnetV2, Xception, DPN 等预训练模型的下载:https://github.com/Cadene/pretrained-models.pytorch
具体下载位置是:https://data.lip6.fr/cadene/pretrainedmodels/
- 加载预训练模型
一般使用的是使用 model.load_state_dict()
函数。
model_urls = { 'resnet50': '/home/huihua/NewDisk1/pretrain_parameter/resnet50-19c8e357.pth',}
def resnet50(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model
此时它会到指定的网站下载预训练模型到本地缓存中,本地缓存的位置(Linux系统)一般在:
.cache/torch/checkpoints
PyTorch 在加载模型时候首先检查本地缓存是否已经存在预训练模型,所以在本地缓存汇总预先放入已经下载的模型可快速加载模型。
如果需要更改预训练模型的位置,可以在文件开头加入:
os.environ['TORCH_HOME']= './pretrained_models/'
在 pretrained_models
文件夹下新建一个 checkpoints
文件夹并把预训练模型放入即可。
- 参考
上一篇: FastApi 文件上传upload
下一篇: 显示器维护六大纪律