欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dict: “module.conv0.weight

程序员文章站 2022-05-27 10:32:23
...

在加载已经训练好的模型时,报错。

报错描述:

Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: “module.conv0.weight”, “module.bn0.weight”, “module.bn0.bias”, “module.bn0.running_mean”, “module.bn0.running_var”, “module.conv1.weight”, “module.bn1.weight”, “module.bn1.bias”, “module.bn1.running_mean”, “module.bn1.running_var”, “module.conv2.weight”, “module.bn2.weight”, “module.bn2.bias”, “module.bn2.running_mean”, “module.bn2.running_var”, “module.conv3.weight”, “module.bn3.weight”, “module.bn3.bias”, “module.bn3.running_mean”, “module.bn3.running_var”, “module.conv4.weight”, “module.bn4.weight”, “module.bn4.bias”, “module.bn4.running_mean”, “module.bn4.running_var”, “module.conv5.weight”, “module.bn5.weight”, “module.bn5.bias”, “module.bn5.running_mean”, “module.bn5.running_var”, “module.fc.weight”, “module.fc.bias”.
Unexpected key(s) in state_dict: “epoch”, “state_dict”, “best_prec1”.

原因:

保存模型的代码:

save_checkpoint({
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
    }, is_best, filename=os.path.join(args.save_dir, 'model.th'))

该函数:

# 保存最新和最佳模型
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """
    Save the training model
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

由以上可知,在调用torch.save函数时,state对应的是{ 'state_dict': model.state_dict(), 'best_prec1': best_prec1, },因此在加载的时候要指明值。

解决方案:

修改为:

model.load_state_dict(torch.load("C:\\Users\\83543\\Desktop\\model_best.pth.tar")['state_dict'])
相关标签: torch