[PyTorch Official Docs Study, Part 10] PyTorch: Saving & Loading Model for Inference

    Tech | 2022-08-06

    This series aims to build familiarity with the implementation and workflow of neural-network frameworks by reading the official PyTorch code and documentation.

    PyTorch: Saving & Loading Model for Inference

    This article provides detailed annotations and personal notes on the official document Saving & Loading Model for Inference; discussion is welcome.

    Save/Load state_dict (Recommended)

    Save:

        torch.save(model.state_dict(), PATH)

    Load:

        model = TheModelClass(*args, **kwargs)
        model.load_state_dict(torch.load(PATH))
        model.eval()

    NOTE

    The 1.6 release of PyTorch switched torch.save to use a new zipfile-based file format. torch.load still retains the ability to load files in the old format. If for any reason you want torch.save to use the old format, pass the kwarg _use_new_zipfile_serialization=False.
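    The note above can be sketched as follows; this is a minimal illustration, and the file names are made up for the example:

```python
import torch

state = {"w": torch.ones(2, 3)}

# Default since PyTorch 1.6: the zipfile-based format.
torch.save(state, "new_format.pt")

# Legacy format, e.g. for consumers still on pre-1.6 PyTorch.
torch.save(state, "legacy_format.pt", _use_new_zipfile_serialization=False)

# torch.load reads both formats without extra arguments.
assert torch.equal(torch.load("new_format.pt")["w"],
                   torch.load("legacy_format.pt")["w"])
```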

    Key points: When saving a model for inference, it is only necessary to save the trained parameters, which is exactly what state_dict together with torch.save() accomplishes. The most common file extensions for saved models are .pth and .pt. Before running inference, you must call model.eval() to set dropout and batch normalization layers to evaluation mode; failing to do this will yield inconsistent inference results.

    NOTE

    Notice that the load_state_dict() function takes a dictionary object, NOT a path to a saved object. This means that you must deserialize the saved state_dict before you pass it to the load_state_dict() function. For example, you CANNOT load using model.load_state_dict(PATH).
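    Putting the pieces above together, here is a self-contained sketch of the state_dict workflow; the small model class and file name are illustrative, not taken from the original document:

```python
import torch
import torch.nn as nn

# Illustrative model; any nn.Module works the same way.
class TheModelClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.fc(self.dropout(x))

PATH = "model_state.pt"
model = TheModelClass()
torch.save(model.state_dict(), PATH)   # save only the parameters

# Loading: rebuild the architecture, deserialize, then restore.
model2 = TheModelClass()
state_dict = torch.load(PATH)          # deserialize first...
model2.load_state_dict(state_dict)     # ...then pass the dict, never the path
model2.eval()                          # dropout/batchnorm -> eval mode

# Parameters round-trip exactly.
for k in model.state_dict():
    assert torch.equal(model.state_dict()[k], model2.state_dict()[k])
```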

    Save/Load Entire Model

    Save:

        torch.save(model, PATH)

    Load:

        # Model class must be defined somewhere
        model = torch.load(PATH)
        model.eval()

    Key points: This save/load approach saves the entire model using Python's pickle module. The drawback is that the serialized data is bound to the specific classes and the exact directory structure used when the model was saved, because pickle does not save the model class itself; instead, it saves a path to the file containing that class, which is resolved at load time. The most common file extensions for saved models are .pth and .pt. Before running inference, you must call model.eval() to set dropout and batch normalization layers to evaluation mode; failing to do this will yield inconsistent inference results.
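    A minimal sketch of the whole-model variant; the class and file name are illustrative. Note that recent PyTorch releases (2.6+) default torch.load to weights-only loading, so unpickling a full module there requires weights_only=False:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

PATH = "entire_model.pt"
model = Net()
torch.save(model, PATH)   # pickles the whole module object

# The class Net must still be defined/importable at load time.
# weights_only=False permits unpickling arbitrary objects (recent
# PyTorch versions default to weights-only loading for safety).
loaded = torch.load(PATH, weights_only=False)
loaded.eval()

x = torch.randn(3, 4)
assert torch.equal(model(x), loaded(x))   # same parameters, same outputs
```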