PyTorch 部分模型参数加载

    科技2025-09-16  43

    这里参考了一种加载部分模型权重的方法：只把预训练 ResNet-50 中与当前模型参数同名的权重拷贝到当前模型中。

    def initialize_weights(self):
        """Initialise ``self.resnet`` from torchvision's pretrained ResNet-50.

        Every parameter/buffer key that ``self.resnet`` defines is filled with
        the tensor of the same name from the pretrained ResNet-50 state dict.
        Pretrained entries whose keys ``self.resnet`` does not have are simply
        ignored, so the local model may be a subset of ResNet-50.

        Raises:
            RuntimeError: if any key of ``self.resnet`` has no counterpart in
                the pretrained state dict. This replaces the original
                ``assert``, which is stripped when Python runs with ``-O``.
        """
        # NOTE(review): ``models`` is presumably ``torchvision.models`` imported
        # elsewhere; ``pretrained=True`` is deprecated in newer torchvision in
        # favour of ``weights=...`` -- kept as-is for compatibility.
        resnet50 = models.resnet50(pretrained=True)
        pretrained_dict = resnet50.state_dict()
        own_state = self.resnet.state_dict()

        missing = [k for k in own_state if k not in pretrained_dict]
        if missing:
            # Keys renamed in the local model (e.g. with a "_1"/"_2" suffix)
            # would need an explicit name mapping before they can be loaded.
            raise RuntimeError(
                f"{len(missing)} key(s) of self.resnet not found in the "
                f"pretrained resnet50 state dict, e.g. {missing[:10]}"
            )

        # Take the pretrained tensor for every key the local model defines.
        all_params = {k: pretrained_dict[k] for k in own_state}
        self.resnet.load_state_dict(all_params)
        print('[INFO] initialize weights from resnet50')
    Processed: 0.009, SQL: 8