欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Unexpected key(s) in state_dict: “dense_block1.denselayer1.norm.1

程序员文章站 2022-05-27 10:32:29
...

 

Unexpected key(s) in state_dict: "dense_block1.denselayer1.norm.1

 

from torchvision.models import densenet121
from collections import OrderedDict

model = densenet121(pretrained=False)

state_dict =torch.load(model_weight_path)
# 初始化一个空 dict
new_state_dict = OrderedDict()
# 修改 key
for k, v in state_dict.items():
    if 'denseblock' in k:
        param = k.split(".")
        k = ".".join(param[:-3] + [param[-3]+param[-2]] + [param[-1]])
    new_state_dict[k] = v
    model.load_state_dict(new_state_dict)

我的解决方法:

# 初始化一个空 dict
new_state_dict = OrderedDict()
# 修改 key
for k, v in state_dict.items():
    k=k.replace('module.', '')
    if 'dense_block' in k:
        if "norm" in k or "conv.1" in k or "conv.2" in k:
            param = k.split(".")
            k = ".".join(param[:-3] + [param[-3]+param[-2]] + [param[-1]])
        new_state_dict[k] = v
    else:
        new_state_dict[k] = v

 

相关标签: torch