Fixing the PyTorch error Missing key(s) in state_dict: "resnet.conv1.0.weight" and Unexpected key(s) in state_dict
程序员文章站
2022-05-27 10:32:35
...
Running predict.py failed with the following error:
RuntimeError: Error(s) in loading state_dict for VisitNet:
Missing key(s) in state_dict: "resnet.conv1.0.weight", "resnet.conv1.1.weight", "resnet.conv1.1.bias", "resnet.conv1.1.running_mean", "resnet.conv1.1.running_var", "resnet.conv1.3.weight", "resnet.conv1.4.weight", "resnet.conv1.4.bias", "resnet.conv1.4.running_mean", "resnet.conv1.4.running_var", "resnet.conv1.6.weight", …
Unexpected key(s) in state_dict: "module.resnet.conv1.0.weight", "module.resnet.conv1.1.weight", "module.resnet.conv1.1.bias", …
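The mismatch can be reproduced without the original project. The sketch below is a minimal CPU-only reproduction under a few assumptions: TinyNet is a hypothetical stand-in for VisitNet, and an in-memory io.BytesIO buffer stands in for the checkpoint file at model_path. It saves weights from a DataParallel-wrapped model and loads them into a plain one.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the article's VisitNet."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)


# "Training" side: the model is wrapped in DataParallel before saving.
# On a CPU-only machine DataParallel simply keeps the net as .module.
buffer = io.BytesIO()  # hypothetical in-memory stand-in for the .pth file
torch.save(nn.DataParallel(TinyNet()).state_dict(), buffer)
buffer.seek(0)

# "Predict" side: a plain, unwrapped model tries to load the checkpoint
model = TinyNet()
try:
    model.load_state_dict(torch.load(buffer))
    error_text = ""
except RuntimeError as exc:
    error_text = str(exc)

print(error_text)
```

The printed message has the same shape as the one above: the plain names (fc.weight, fc.bias) are reported as missing and the module.-prefixed names as unexpected.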
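The cause: the training script wrapped the model with
model = nn.DataParallel(model).cuda()
but predict.py did not. nn.DataParallel stores the original network as its module attribute, so every key in a state_dict saved from the wrapped model carries a "module." prefix ("module.resnet.conv1.0.weight" instead of "resnet.conv1.0.weight"). When that checkpoint is loaded into the bare model, the prefixed names are reported as unexpected and the plain names as missing.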
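The prefix is easy to see by comparing key names directly. This is a small sketch assuming a hypothetical TinyNet with a single linear layer; it runs on CPU because DataParallel only needs to hold the network, not run it.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the article's VisitNet."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)


plain = TinyNet()
wrapped = nn.DataParallel(TinyNet())  # the original net is stored as wrapped.module

plain_keys = list(plain.state_dict().keys())
wrapped_keys = list(wrapped.state_dict().keys())
print(plain_keys)    # ['fc.weight', 'fc.bias']
print(wrapped_keys)  # ['module.fc.weight', 'module.fc.bias']
```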
Solution:
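Add the same wrapper in predict.py,
model = nn.DataParallel(model).cuda()
before calling load_state_dict, so the model's key names match the checkpoint. The following works: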
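import torch
import torch.nn as nn

model = Net().cuda()  # Net and model_path come from the project's own code
# Optional check: these parameter names have no "module." prefix yet
for name, param in model.named_parameters():
    print(name, param.shape)
# Wrap exactly as in training so the keys gain the same "module." prefix
model = nn.DataParallel(model).cuda()
model.load_state_dict(torch.load(model_path))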
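A minimal sketch for checking the fix, again assuming a hypothetical TinyNet and an in-memory buffer in place of the real checkpoint file. Option 1 is the article's approach (wrap before loading). Option 2 is a different, commonly used technique swapped in for comparison: keep the plain model and strip the "module." prefix from the keys; strip_module_prefix is a hypothetical helper written here, not a PyTorch API.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the article's VisitNet."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)


def strip_module_prefix(state_dict, prefix="module."):
    """Hypothetical helper: copy state_dict, dropping a leading 'module.' from keys."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Simulated checkpoint saved from a DataParallel-wrapped model
buffer = io.BytesIO()
torch.save(nn.DataParallel(TinyNet()).state_dict(), buffer)
buffer.seek(0)
checkpoint = torch.load(buffer)

# Option 1 (the article's fix): wrap the model the same way before loading
wrapped = nn.DataParallel(TinyNet())
result_wrapped = wrapped.load_state_dict(checkpoint)

# Option 2 (alternative): keep the plain model and rename the keys instead
plain = TinyNet()
result_plain = plain.load_state_dict(strip_module_prefix(checkpoint))

# Both results should list no missing and no unexpected keys
print(result_wrapped)
print(result_plain)
```

Option 2 avoids needing DataParallel at inference time. If the training script can be changed, saving model.module.state_dict() instead of model.state_dict() produces a checkpoint without the prefix in the first place.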