欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch加载模型报错Unexpected key(s) in state_dict: module.conv1.weight, module.bn1

程序员文章站 2022-03-03 14:37:24
...

文章目录

代码

import models

arch = 'resnet50'
model = models.__dict__[arch]()

checkpoint = torch.load(ckptFile)
model.load_state_dict(checkpoint['state_dict'])

model = torch.nn.DataParallel(model).cuda()

报错

Traceback (most recent call last):
  File "/home/user1/project1/utils/eval.py", line 193, in <module>
    test_models(r'/home/user1/models')
  File "/home/user1/project1/utils/eval.py", line 157, in test_models
    model.load_state_dict(checkpoint['state_dict'])
  File "/home/user1/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 777, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", ...". 
	Unexpected key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias"...

原因

哥们,你顺序搞错了,应该先把模型放到GPU上,然后再开始load参数

解决

正确的顺序

import models

arch = 'resnet50'
model = models.__dict__[arch]()

model = torch.nn.DataParallel(model).cuda()

checkpoint = torch.load(ckptFile)
model.load_state_dict(checkpoint['state_dict'])