欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

解决Pytorch下报错Missing key(s) in state_dict: "resnet.conv1.0.weight",和 Unexpected key(s) in state_dict

程序员文章站 2022-05-27 10:32:35
...

运行predict.py时报错如下:
RuntimeError: Error(s) in loading state_dict for VisitNet:
Missing key(s) in state_dict: “resnet.conv1.0.weight”, “resnet.conv1.1.weight”, “resnet.conv1.1.bias”, “resnet.conv1.1.running_mean”, “resnet.conv1.1.running_var”, “resnet.conv1.3.weight”, “resnet.conv1.4.weight”, “resnet.conv1.4.bias”, “resnet.conv1.4.running_mean”, “resnet.conv1.4.running_var”, “resnet.conv1.6.weight”, "

Unexpected key(s) in state_dict: “module.resnet.conv1.0.weight”, “module.resnet.conv1.1.weight”, “module.resnet.conv1.1.bias”, …

原因是训练时加入了

model = nn.DataParallel(model).cuda()

而测试时没有加入。

解决方法:
predict.py加入

model = nn.DataParallel(model).cuda()

即可,如下OK。

model = Net().cuda()
for name, param in model.named_parameters():
    print(name, param.shape)
model = nn.DataParallel(model).cuda()
model.load_state_dict(torch.load(model_path))
相关标签: bug bug