Pytorch--训练并保存的模型利用 torch.load() 在GPU、CPU上加载

    科技2022-08-28  97

     

    # 直接加载模型 model.load_state_dict(torch.load('./data/my_model.pkl')) #GPU训练的模型加载到CPU上: model.load_state_dict(torch.load('./data/my_model.pkl', map_location=lambda storage, loc: storage)) #加载到GPU1上: model.load_state_dict(torch.load('./data/my_model.pkl', map_location=lambda storage, loc: storage.cuda(1))) #从GPU1 移动到 GPU0: model.load_state_dict(torch.load('./data/my_model.pkl', map_location={'cuda:1':'cuda:0'}))

     

    Processed: 0.013, SQL: 9