[PyTorch Official Docs Study, Part 9] PyTorch: state_dict

    Tech | 2022-08-05

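    This series works through PyTorch's official code to become familiar with how the various components of a neural network are implemented and how the overall workflow fits together.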


    This post is a detailed annotation of, and my personal notes on, the official tutorial "What is a state_dict?". Discussion is welcome.

    What state_dict is for

    In PyTorch, the learnable parameters (weights and biases) of a torch.nn.Module model are contained in the model's parameters, which are accessed with model.parameters(). But how do you associate each layer of the network with its entries in model.parameters()? That is what state_dict is for. A state_dict is simply a Python dictionary that maps each layer (more precisely, each named parameter or buffer, such as conv1.weight) to its tensor.

    What state_dict covers

    Only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (such as batchnorm's running_mean) have entries in a model's state_dict. Optimizer objects (torch.optim) also have a state_dict, which holds the optimizer's state together with the hyperparameters it uses.

    Example

    The following code defines the classifier from the "Training a classifier" tutorial and prints the state_dict of both the model and its optimizer:

```python
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # For 3x32x32 inputs, two conv(5) + pool(2) stages leave 16 x 5 x 5 features
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
```

    Output:

```
Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias       torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias       torch.Size([16])
fc1.weight       torch.Size([120, 400])
fc1.bias         torch.Size([120])
fc2.weight       torch.Size([84, 120])
fc2.bias         torch.Size([84])
fc3.weight       torch.Size([10, 84])
fc3.bias         torch.Size([10])
Optimizer's state_dict:
state            {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
```

    Note that the max-pooling layer self.pool has no parameters, so it contributes no entries to the model's state_dict.
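    The classifier above has no registered buffers, so its state_dict shows only parameters. To see buffers appear as well, here is a minimal sketch assuming a recent PyTorch version and a hypothetical SmallNet model (not part of the tutorial) that combines a convolution, a BatchNorm layer, and two parameter-free layers. It prints every key with its shape and then lists the learnable parameters and the registered buffers separately.

```python
import torch.nn as nn


class SmallNet(nn.Module):
    """Hypothetical model mixing layers with and without state (not from the tutorial)."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, 3)  # learnable weight and bias
        self.bn = nn.BatchNorm2d(4)     # learnable weight/bias plus registered buffers
        self.pool = nn.MaxPool2d(2, 2)  # no parameters, no buffers
        self.relu = nn.ReLU()           # no parameters, no buffers

    def forward(self, x):
        return self.pool(self.relu(self.bn(self.conv(x))))


net = SmallNet()
sd = net.state_dict()

# Keys have the form "<submodule name>.<parameter or buffer name>"
for name, tensor in sd.items():
    print(name, tuple(tensor.shape))

# In this model, learnable parameters and registered buffers together make up the state_dict
param_names = {name for name, _ in net.named_parameters()}
buffer_names = {name for name, _ in net.named_buffers()}
print("parameters:", sorted(param_names))
print("buffers:", sorted(buffer_names))
```

    On a recent PyTorch the keys should be conv.weight, conv.bias, bn.weight, bn.bias, bn.running_mean, bn.running_var and bn.num_batches_tracked. The pool and relu submodules produce no keys at all.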
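    In the optimizer output, state is {} because no training step has been taken yet. The following minimal sketch uses a hypothetical one-layer nn.Linear(4, 2) model and made-up random data to show how state fills in after a single SGD step with momentum, while param_groups keeps the hyperparameters. The identifiers in "params" differ across PyTorch versions: older releases, such as the one that produced the output above, store object ids, while newer ones store integer indices. For that reason, the sketch only checks that these identifiers match the keys of state.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical tiny model: one weight (2, 4) and one bias (2,)
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# No step taken yet, so SGD holds no per-parameter state
before = optimizer.state_dict()
print(before["state"])  # {}

# One training step on random (made-up) data
x = torch.randn(8, 4)
y = torch.randn(8, 2)
loss = F.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

after = optimizer.state_dict()

# With momentum > 0, SGD stores a "momentum_buffer" for each updated parameter
for key, param_state in after["state"].items():
    print(key, list(param_state.keys()))

# Hyperparameters live in param_groups; "params" identifies the parameters in the group
group = after["param_groups"][0]
print(group["lr"], group["momentum"], group["params"])
```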