pytorch加载模型时报错
RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: “module.backbone.layers.0.stage_1.layers.0.weight”,
这是因为加载的预训练模型之前使用了torch.nn.DataParallel(),而此时没有使用,所以可以加上该模块或者去掉。
解决方法一:加上torch.nn.DataParallel()模块
model
= torch
.nn
.DataParallel
(model
)
torch
.backends
.cudnn
.benchmark
= True
model
.load_state_dict
(torch
.load
(model_path
))
解决方法二:将字典键值中的module.替换掉
model
.load_state_dict
({k
.replace
('module.', ''): v
for k
, v
in torch
.load
(model_path
).items
()})