1. Serialization and Deserialization
2. Two Ways to Save and Load a Model
3. Resuming Training from a Checkpoint
During training, a model lives in memory. To keep it for later use, it must be moved from memory to disk for persistent storage.
Saving and loading a model is also called model serialization and deserialization. Why these names?
In memory, a model exists as a Python object, but objects in memory do not persist: they vanish when the process ends. On disk, data is stored as a binary sequence of 0s and 1s. Serialization is the process of taking an object in memory and writing it to disk as such a binary sequence.
What, then, is deserialization? It is the reverse: reading the stored binary sequence back into memory to reconstruct the object.
(1) torch.save
obj: the object to save. It can be a model, a tensor, parameters, and so on; in Python, everything is an object.
f: the output path.
(2) torch.load
f: the file path.
map_location: specifies where to map the stored tensors at load time; its use is covered in detail in the GPU lesson.
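As a quick, self-contained illustration (my sketch, not from the original lesson; the file name demo_tensor.pkl is made up), these two functions round-trip any picklable object, and map_location can remap storage locations at load time:

import torch

t = torch.randn(3, 3)                  # any picklable object works: tensor, model, dict, ...
torch.save(t, "./demo_tensor.pkl")     # serialize the object to disk

t_load = torch.load("./demo_tensor.pkl")                       # deserialize back into memory
t_cpu = torch.load("./demo_tensor.pkl", map_location="cpu")    # remap storages to CPU
print(torch.equal(t, t_load))          # True: the round trip preserves the values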
(1) Saving the entire model
The lazy approach: one call saves everything, but it is more time-consuming and produces larger files.
Example: save the entire model, then load it back.
Step 1: save the entire model
# -*- coding: utf-8 -*-
"""
# @file name : model_save1.py
# @brief     : saving the model
"""
import torch
import numpy as np
import torch.nn as nn


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(20191104)


net = LeNet2(classes=2019)

# "training"
print("before training: ", net.features[0].weight[0, ...])
net.initialize()
print("after training: ", net.features[0].weight[0, ...])

# save the entire model
path_model = "./model.pkl"
torch.save(net, path_model)

Output:
A new file, model.pkl, now appears in the directory.
Step 2: load the entire model.
# -*- coding: utf-8 -*-
"""
# @file name : model_load1.py
# @brief     : loading the model
"""
import torch
import numpy as np
import torch.nn as nn


class LeNet2(nn.Module):   # this class definition must be present, otherwise loading raises an error
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(20191104)


# ================================== load net ===========================
path_model = "./model.pkl"
net_load = torch.load(path_model)

print(net_load)

Output:
Stepping through in the debugger also confirms that the weights have been loaded.
(2) Saving only the model parameters
This is the approach officially recommended by PyTorch.
Example:
Step 1: save the model's state_dict
# -*- coding: utf-8 -*-
"""
# @file name : model_save2.py
# @brief     : saving the model
"""
import torch
import numpy as np
import torch.nn as nn


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(20191104)


net = LeNet2(classes=2019)

# "training"
print("before training: ", net.features[0].weight[0, ...])
net.initialize()
print("after training: ", net.features[0].weight[0, ...])

# save the model parameters
net_state_dict = net.state_dict()            # step 1: extract the learnable parameters as a state_dict
path_state_dict = "./model_state_dict.pkl"
torch.save(net_state_dict, path_state_dict)  # step 2: save the state_dict to disk
Output:
Step 2: load the model parameters
# -*- coding: utf-8 -*-
"""
# @file name : model_load2.py
# @brief     : loading the model
"""
import torch
import numpy as np
import torch.nn as nn


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(20191104)


net_new = LeNet2(classes=2019)

print("before loading: ", net_new.features[0].weight[0, ...])

# ================================== load the parameters
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)   # step 1: load the state_dict from disk into memory
# print(state_dict_load.keys())
net_new.load_state_dict(state_dict_load)        # step 2: copy the state_dict into the model

print("after loading: ", net_new.features[0].weight[0, ...])
Output:
Once more: the second approach, saving the state_dict, is the officially recommended one.
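A side note beyond the lesson: a state_dict is just an ordered mapping from parameter names to tensors, so it can be inspected or partially loaded. PyTorch's load_state_dict accepts strict=False to tolerate missing or unexpected keys, which is handy after refactoring a model. A minimal sketch reusing LeNet2 and model_state_dict.pkl from the example above:

# assumes the LeNet2 class and model_state_dict.pkl from the example above
state_dict = torch.load("./model_state_dict.pkl")
print(state_dict.keys())    # e.g. odict_keys(['features.0.weight', 'features.0.bias', ...])

net = LeNet2(classes=2019)
# strict=False ignores keys that do not match the current model definition
missing, unexpected = net.load_state_dict(state_dict, strict=False)
print(missing, unexpected)  # both empty when the architectures match exactly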
What data does resuming from a checkpoint require us to save?
Training involves data, a model, a loss function, and an optimizer. Of these, only the model and the optimizer change as iterations proceed: the model holds the learnable parameters such as weights, while the optimizer holds state such as momentum buffers, which depend on information from previous steps.
A checkpoint therefore needs to store: the model's state, the optimizer's state, and the epoch index (see the minimal sketch below).
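Before walking through the full scripts, here is a minimal sketch of the save/resume pattern, assuming net, optimizer, scheduler, and epoch already exist as in the training code below (the file name ckpt.pkl is made up):

# --- saving, e.g. once every few epochs ---
checkpoint = {"model_state_dict": net.state_dict(),
              "optimizer_state_dict": optimizer.state_dict(),
              "epoch": epoch}
torch.save(checkpoint, "./ckpt.pkl")

# --- resuming after an interruption ---
checkpoint = torch.load("./ckpt.pkl")
net.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"]
scheduler.last_epoch = start_epoch   # keep the LR schedule in sync with the resumed epoch
# training then restarts from start_epoch + 1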
Example:
Step 1: simulate an interruption
# -*- coding: utf-8 -*-
"""
# @file name : save_checkpoint.py
# @brief     : simulate an unexpected stop during training
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from PIL import Image
from matplotlib import pyplot as plt
import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)

from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed
import torchvision

set_seed(1)  # set the random seed
rmb_label = {"1": 0, "100": 1}

# hyperparameter settings
checkpoint_interval = 5
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 data ============================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "data", "rmb_split"))
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
if not os.path.exists(split_dir):
    raise Exception(r"data {} does not exist; go back to lesson-06\1_split_dataset.py to generate it".format(split_dir))

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# build MyDataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# build DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================
net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()  # choose the loss function

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)  # set the learning-rate decay policy

# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

start_epoch = -1
for epoch in range(start_epoch+1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # track classification accuracy
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # print training info
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update the learning rate

    if (epoch+1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)

    if epoch > 5:
        print("training interrupted unexpectedly...")
        break

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val/len(valid_loader), correct_val / total_val))

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval  # valid records per-epoch loss, so convert the record points to iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

Output:
Step 2: resume training
# -*- coding: utf-8 -*-
"""
# @file name : save_checkpoint.py
# @brief     : resume training from a checkpoint
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from PIL import Image
from matplotlib import pyplot as plt
import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)

from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed
import torchvision

set_seed(1)  # set the random seed
rmb_label = {"1": 0, "100": 1}

# hyperparameter settings
checkpoint_interval = 5
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 data ============================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "data", "rmb_split"))
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
if not os.path.exists(split_dir):
    raise Exception(r"data {} does not exist; go back to lesson-06\1_split_dataset.py to generate it".format(split_dir))

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# build MyDataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# build DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================
net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()  # choose the loss function

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)  # set the learning-rate decay policy

# ============================ step 5+/5 resume from checkpoint ============================
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']

scheduler.last_epoch = start_epoch  # keep the LR schedule in sync with the resumed epoch

# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # track classification accuracy
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # print training info
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update the learning rate

    if (epoch+1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "loss": loss,
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)

    # if epoch > 5:
    #     print("training interrupted unexpectedly...")
    #     break

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val/len(valid_loader), correct_val / total_val))

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval  # valid records per-epoch loss, so convert the record points to iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
Output:
