PyTorch model loading error: Unexpected key(s) in state_dict: module.conv1.weight, module.bn1
程序员文章站
2022-03-03 14:37:24
...
Code
import torch
import models
arch = 'resnet50'
model = models.__dict__[arch]()
checkpoint = torch.load(ckptFile)
# Fails here: the checkpoint keys carry a "module." prefix that the bare model lacks
model.load_state_dict(checkpoint['state_dict'])
model = torch.nn.DataParallel(model).cuda()
Error
Traceback (most recent call last):
  File "/home/user1/project1/utils/eval.py", line 193, in <module>
    test_models(r'/home/user1/models')
  File "/home/user1/project1/utils/eval.py", line 157, in test_models
    model.load_state_dict(checkpoint['state_dict'])
  File "/home/user1/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 777, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ResNet:
    Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", ...".
    Unexpected key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias"...
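Cause
Buddy, you've got the order backwards. The checkpoint was evidently saved from a model wrapped in torch.nn.DataParallel, so every key in its state_dict starts with "module.", while the bare ResNet expects the same keys without that prefix. Wrap the model in DataParallel and put it on the GPU first, and only then load the parameters.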
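To see where the prefix comes from, here is a minimal sketch that uses a tiny nn.Linear as a hypothetical stand-in for resnet50 (no checkpoint file or GPU needed; when no GPU is available, DataParallel simply wraps the module). It prints the key names with and without the wrapper and then reproduces the same kind of error:

```python
import torch.nn as nn

# Hypothetical stand-in for resnet50: any nn.Module shows the same behavior
model = nn.Linear(4, 2)
# DataParallel registers the wrapped network as its submodule named "module"
wrapped = nn.DataParallel(model)

bare_keys = list(model.state_dict().keys())
wrapped_keys = list(wrapped.state_dict().keys())
print(bare_keys)     # ['weight', 'bias']
print(wrapped_keys)  # ['module.weight', 'module.bias']

# Loading the wrapper's keys into the bare model fails with missing/unexpected keys
error_message = ""
try:
    model.load_state_dict(wrapped.state_dict())
except RuntimeError as e:
    error_message = str(e)
print(error_message)
```

The printed error has the same shape as the traceback above: the bare names are reported as missing and the "module."-prefixed names as unexpected.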
Solution
The correct order:
import torch
import models
arch = 'resnet50'
model = models.__dict__[arch]()
# Wrap first, so the model's keys get the same "module." prefix as the checkpoint
model = torch.nn.DataParallel(model).cuda()
checkpoint = torch.load(ckptFile)
model.load_state_dict(checkpoint['state_dict'])
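If you would rather evaluate with a plain, unwrapped model, a common alternative (not the fix described above) is to strip the "module." prefix from the checkpoint keys before loading. A minimal sketch, again with nn.Linear standing in for resnet50; strip_module_prefix is a hypothetical helper, not a PyTorch API:

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from every key that starts with it."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


# Hypothetical stand-in for a checkpoint saved from a DataParallel-wrapped model
saved = nn.DataParallel(nn.Linear(4, 2)).state_dict()

# Load into a bare model without wrapping it first; strict loading still checks every key
model = nn.Linear(4, 2)
model.load_state_dict(strip_module_prefix(saved))
```

Applied to the article's code, this would be model.load_state_dict(strip_module_prefix(checkpoint['state_dict'])) on the model returned by models.__dict__[arch](). Unlike passing strict=False, this keeps strict key checking, so a genuine mismatch still raises instead of silently leaving weights uninitialized. On a CPU-only machine, a checkpoint containing CUDA tensors also needs torch.load(..., map_location="cpu").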