PyTorch error when exporting to ONNX: Unexpected key(s) in state_dict: "module.classifier.0.weight"
程序员文章站
2022-05-27 10:32:17
flyfish
Traceback (most recent call last):
model.load_state_dict(checkpoint['state_dict'])
File "torch/nn/modules/module.py", line 1052, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Model:
Missing key(s) in state_dict: "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias", "classifier.5.weight", "classifier.5.bias".
Unexpected key(s) in state_dict: "module.classifier.0.weight", "module.classifier.0.bias", "module.classifier.3.weight", "module.classifier.3.bias", "module.classifier.5.weight", "module.classifier.5.bias".
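Cause of the error
The model was trained on multiple GPUs and saved directly, so every key in the checkpoint carries an extra "module." prefix: module.classifier.0.weight should be classifier.0.weight
The usual single-GPU version is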
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = net().to(device)
The multi-GPU version is
model = torch.nn.DataParallel(net()).cuda()
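To see where the prefix comes from, the following minimal sketch compares the state_dict keys of a plain network with those of the same network wrapped in DataParallel. The Net class here is a hypothetical stand-in for the article's net(), with made-up layer sizes; only the key names matter.

```python
import torch.nn as nn


# Hypothetical stand-in for the article's net(); the layer sizes are arbitrary.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2)
        )

    def forward(self, x):
        return self.classifier(x)


plain_keys = list(Net().state_dict().keys())

# DataParallel keeps the wrapped network in an attribute named "module",
# so every key of its state_dict gains a "module." prefix.
wrapped_keys = list(nn.DataParallel(Net()).state_dict().keys())

print(plain_keys)
print(wrapped_keys)
```

The wrapped keys are the plain keys with "module." in front, which is the mismatch reported in the traceback.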
The checkpoint stores more than the network weights; it also holds the epoch and other training information
state = {
    'epoch': epoch + 1,
    'state_dict': model.state_dict(),
    'test_acc': test_acc,
    'best_acc': best_acc,
    'optimizer': optimizer.state_dict(),
}
filepath = os.path.join("checkpoint", filename)
torch.save(state, filepath)
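If the training script can still be changed, one possible alternative (not part of the original code) is to save the state_dict of model.module, the unwrapped network, so the prefix never reaches the checkpoint. The sketch below assumes a small stand-in network and writes to an in-memory buffer instead of a checkpoint file.

```python
import io

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 2))  # hypothetical stand-in network
model = nn.DataParallel(net)

# model.module is the original, unwrapped network, so its keys have no prefix.
state = {"epoch": 1, "state_dict": model.module.state_dict()}

buffer = io.BytesIO()  # in-memory file, used here instead of a checkpoint path
torch.save(state, buffer)
buffer.seek(0)

checkpoint = torch.load(buffer, map_location="cpu")
saved_keys = list(checkpoint["state_dict"].keys())
print(saved_keys)

# A fresh, unwrapped network now loads the checkpoint with the default strict=True.
fresh = nn.Sequential(nn.Linear(4, 2))
fresh.load_state_dict(checkpoint["state_dict"])
```

Checkpoints that were already saved with the prefix still need their keys fixed at load time.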
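Solution: when loading the checkpoint for ONNX export, strip the "module." prefix from every key. The prefix is 7 characters long.
Approach 1
import collections
import torch

# args.model is the path to the checkpoint file
checkpoint = torch.load(args.model)
new_state_dict = collections.OrderedDict()
for k, v in checkpoint['state_dict'].items():
    name = k[7:]  # remove the 7-character "module." prefix
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)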
Approach 2
model.load_state_dict({k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()})
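Both approaches assume that every key starts with "module.": slicing with k[7:] would truncate a key that lacks the prefix, and replace() would also remove "module." from the middle of a key. A slightly more defensive variant could look like the sketch below. The helper name strip_module_prefix is made up for illustration, and plain integers stand in for tensors because the helper only inspects the keys.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Only touch keys that really start with the prefix; leave others intact.
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Plain values stand in for tensors here.
example = OrderedDict([
    ("module.classifier.0.weight", 1),
    ("classifier.0.bias", 2),
])
cleaned = strip_module_prefix(example)
print(list(cleaned.keys()))
```

With the article's variables, the call would be model.load_state_dict(strip_module_prefix(checkpoint['state_dict'])).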
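Not recommended: strict=False. It suppresses the error, but the mismatched keys are simply skipped, so the classifier layers are never loaded and keep their initial weights.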
model.load_state_dict(checkpoint['state_dict'],strict=False)
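The effect of strict=False can be checked directly, because load_state_dict returns the keys it could not match. The sketch below assumes a small stand-in network and simulates a prefixed checkpoint filled with zeros. The call succeeds, yet the weights stay exactly as they were.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 2))  # hypothetical stand-in network
before = model[0].weight.detach().clone()

# Simulate a checkpoint saved from a DataParallel-wrapped model.
prefixed = {
    "module." + k: torch.zeros_like(v) for k, v in model.state_dict().items()
}

# No exception is raised, but nothing matches, so nothing is loaded.
result = model.load_state_dict(prefixed, strict=False)

print(result.missing_keys)
print(result.unexpected_keys)

weights_unchanged = torch.equal(before, model[0].weight)
print(weights_unchanged)
```

An ONNX file exported from a model loaded this way would contain the untrained classifier weights.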