完整报错信息:
RuntimeError:Error(s) in loading state_dict for DataParallel:
Unexpected key(s) in state_dict:“module.resnet.bn1.num_batches_tracked”,"module.resnet.layer1.0.bn1.num_batches_tracked"等等,遇到这种错误,说明你训练模型和测试加载模型所使用的环境不一致,解决方法:
1>将环境改为一致
2>我当时的环境是:训练pytorch1.0,测试环境Pytorch0.4,只需要把加载模型那一块做一个简单的修改,如以下代码:

checkpoint_file = os.path.join(args.checkpoint, args.test+'.pth.tar')
checkpoint = torch.load(checkpoint_file)
model.load_state_dict(checkpoint['state_dict'])

上面是我出错的代码,解决方法是在最后一行的括号里面加上False,如下:

checkpoint_file = os.path.join(args.checkpoint, args.test+'.pth.tar')
checkpoint = torch.load(checkpoint_file) 
model.load_state_dict(checkpoint['state_dict'],False) # 修改处

2019.11.7
model.load_state_dict(state_dict, strict=True)

Copies parameters and buffers from :attr:state_dict into this module and its descendants. If :attr:strict is True, then the keys of :attr:state_dict must exactly match the keys returned by this module’s :meth:~torch.nn.Module.state_dict function
从属性state_dict里面复制参数到这个模块和它的后代。如果strict为True, state_dict的keys必须完全与这个模块的方法返回的keys相匹配。如果为False,就不需要保证匹配。

Arguments:
state_dict (dict): a dict containing parameters and persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys in :attr:state_dict match the keys returned by this module’s:meth:~torch.nn.Module.state_dict function. Default: True

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐