PyTorch error when exporting to ONNX: Unexpected key(s) in state_dict: "module.classifier.0.weight"

程序员文章站 2022-05-27 10:32:17
...

flyfish

Traceback (most recent call last):
    model.load_state_dict(checkpoint['state_dict'])
  File "torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Model:
	Missing key(s) in state_dict: "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias", "classifier.5.weight", "classifier.5.bias". 
	Unexpected key(s) in state_dict: "module.classifier.0.weight", "module.classifier.0.bias", "module.classifier.3.weight", "module.classifier.3.bias", "module.classifier.5.weight", "module.classifier.5.bias". 

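Cause of the error
The model was trained on multiple GPUs (wrapped in torch.nn.DataParallel) and its state_dict was saved directly. DataParallel keeps the original network in its "module" attribute, so every saved key carries an extra "module." prefix: module.classifier.0.weight should be classifier.0.weight. An unwrapped model expects the unprefixed names, which is why the error lists the same layers as both missing and unexpected keys.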

The usual single-GPU setup:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = net().to(device)

The multi-GPU setup:

model = torch.nn.DataParallel(net()).cuda()
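To see where the prefix comes from, here is a small sketch. It assumes a toy nn.Sequential (named base here, a hypothetical stand-in for the article's net()) and leaves out .cuda() so it also runs on a CPU-only machine:

```python
import torch.nn as nn

# Hypothetical stand-in for net(): layers 0 and 2 hold weights, ReLU has none.
base = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
print(list(base.state_dict().keys()))
# ['0.weight', '0.bias', '2.weight', '2.bias']

# DataParallel keeps the original network in its `module` attribute,
# so its state_dict reports every key under the "module." prefix.
wrapped = nn.DataParallel(base)
print(list(wrapped.state_dict().keys()))
# ['module.0.weight', 'module.0.bias', 'module.2.weight', 'module.2.bias']

# The wrapped network itself still uses the unprefixed names.
print(list(wrapped.module.state_dict().keys()))
# ['0.weight', '0.bias', '2.weight', '2.bias']
```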

The checkpoint stores not only the network weights but also the epoch and other information:

state = {
    'epoch': epoch + 1,
    'state_dict': model.state_dict(),
    'test_acc': test_acc,
    'best_acc': best_acc,
    'optimizer': optimizer.state_dict(),
}

filepath = os.path.join("checkpoint", filename)
torch.save(state, filepath)
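The failure can be reproduced end to end in a few lines. This is a sketch that assumes a hypothetical make_net() in place of net() and a temporary directory instead of the "checkpoint" folder; it saves a DataParallel checkpoint shaped like the one above and then loads it into an unwrapped network:

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_net():
    # Hypothetical stand-in for net(); layers 0 and 2 hold weights.
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


# Training side: wrap in DataParallel and save a checkpoint dict.
model = nn.DataParallel(make_net())
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
state = {
    "epoch": 1,
    "state_dict": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}

with tempfile.TemporaryDirectory() as tmp:
    filepath = os.path.join(tmp, "checkpoint.pth")
    torch.save(state, filepath)
    checkpoint = torch.load(filepath, map_location="cpu")

# Export side: load into a plain, unwrapped network.
plain = make_net()
try:
    plain.load_state_dict(checkpoint["state_dict"])
    error = None
except RuntimeError as exc:
    error = str(exc)
print(error)
```

The printed message has the same shape as the traceback at the top: the unprefixed names are reported as missing and the "module."-prefixed ones as unexpected.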

Solution: when loading the checkpoint for ONNX export, remove the "module." prefix (7 characters) from every key.

Option 1

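import collections
import torch

# args.model is the path to the checkpoint file
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only machine
checkpoint = torch.load(args.model, map_location="cpu")
new_state_dict = collections.OrderedDict()

for k, v in checkpoint['state_dict'].items():
    # strip the 7-character "module." prefix, but only when it is present
    name = k[7:] if k.startswith('module.') else k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)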

Option 2

# replace only the first occurrence, i.e. the leading "module." prefix
model.load_state_dict({k.replace('module.', '', 1): v for k, v in checkpoint['state_dict'].items()})
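Both options can be checked on a toy model. The sketch below, again using a hypothetical make_net() in place of net() and an in-memory dict in place of checkpoint['state_dict'], strips the prefix each way and loads the result with the default strict=True, which raises if any key is still wrong:

```python
import collections

import torch
import torch.nn as nn


def make_net():
    # Hypothetical stand-in for net().
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


# Plays the role of checkpoint['state_dict'] saved from a DataParallel model.
source = nn.DataParallel(make_net())
saved = {k: v.cpu() for k, v in source.state_dict().items()}

# Option 1: rebuild an OrderedDict, stripping the prefix only where present.
opt1 = collections.OrderedDict(
    (k[len("module."):] if k.startswith("module.") else k, v) for k, v in saved.items()
)
# Option 2: dict comprehension with str.replace, first occurrence only.
opt2 = {k.replace("module.", "", 1): v for k, v in saved.items()}

results = []
for stripped in (opt1, opt2):
    model = make_net()
    model.load_state_dict(stripped)  # strict=True: raises on any key mismatch
    same = all(
        torch.equal(a, b) for a, b in zip(model.state_dict().values(), saved.values())
    )
    results.append((list(stripped.keys()), same))
print(results)
```

Each entry pairs the cleaned key list with whether the loaded weights match the saved ones.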

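Not recommended: strict=False. With strict=False, load_state_dict silently skips every key that does not match, so none of the "module."-prefixed weights are loaded and the model keeps its random initialization.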

model.load_state_dict(checkpoint['state_dict'],strict=False)
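The sketch below, using a single nn.Linear as a hypothetical stand-in for the real model, shows what strict=False actually does: the call succeeds, but the returned object lists every key as missing or unexpected, and no weight is copied:

```python
import torch
import torch.nn as nn

# Hypothetical "trained" model, wrapped in DataParallel as during training.
trained = nn.DataParallel(nn.Linear(4, 2))
with torch.no_grad():
    trained.module.weight.fill_(1.0)  # pretend these are the trained weights

fresh = nn.Linear(4, 2)
result = fresh.load_state_dict(trained.state_dict(), strict=False)
print(result.missing_keys)     # ['weight', 'bias']
print(result.unexpected_keys)  # ['module.weight', 'module.bias']
print(torch.all(fresh.weight == 1.0).item())  # False: no weight was copied
```

If strict=False is used anyway, checking result.missing_keys and result.unexpected_keys is the minimum needed to notice that nothing was loaded.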