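This series aims to build familiarity with how the various components of a neural network framework are implemented, and the overall workflow, by reading the official PyTorch code.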
[PyTorch Official Documentation Study, Part 9] PyTorch: state_dict
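This post annotates the official documentation page "What is a state_dict?" in detail and adds my own understanding. Discussion is welcome.

What state_dict does
In PyTorch, the learnable parameters (weights and biases) of a torch.nn.Module model are contained in the model's parameters, which you access with model.parameters(). But how do you tell which layer of the network each of those tensors belongs to? This is what state_dict is for. A state_dict is simply a Python dictionary that maps each layer's named parameters and buffers (for example conv1.weight) to the corresponding tensors.

Scope of state_dict
Only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (for example batchnorm's running_mean) have entries in a model's state_dict. Layers with neither, such as max pooling, contribute no entries. Optimizer objects (torch.optim) also have a state_dict, which holds the optimizer's state as well as the hyperparameters it uses.

Example
The following code uses the simple model from the "Training a classifier" tutorial and prints the state_dict of both the model and its optimizer: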
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define the model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)       # learnable: weight, bias
        self.pool = nn.MaxPool2d(2, 2)        # no parameters, so no state_dict entries
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)            # flatten for the fully connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize the model
model = TheModelClass()

# Initialize the optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print the model's state_dict: parameter name -> tensor shape
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print the optimizer's state_dict: "state" and "param_groups"
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
Output:

Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
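Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

Here state is empty because optimizer.step() has not been called yet. The ten numbers under 'params' identify the model's ten parameter tensors (the weights and biases of conv1 through fc3). Their exact form, like the full set of keys in param_groups, depends on the PyTorch version.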
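To see the scoping rule and the optimizer's state in action, here is a minimal sketch. It assumes a recent PyTorch version and uses a small hypothetical model `net` (not part of the official tutorial) that mixes a convolution, a BatchNorm2d layer and two parameter-free layers. The input shape and batch size are arbitrary. The sketch lists the keys in the model's state_dict, separates learnable parameters from registered buffers, and shows that the optimizer's `state` (empty in the output above) is filled in only after `optimizer.step()` runs.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical model: only layers with parameters or buffers produce state_dict keys
net = nn.Sequential(
    nn.Conv2d(3, 6, 5),   # index 0: learnable weight and bias
    nn.BatchNorm2d(6),    # index 1: weight/bias plus buffers (running_mean, ...)
    nn.ReLU(),            # index 2: nothing to store
    nn.MaxPool2d(2, 2),   # index 3: nothing to store
)

# Keys are "<submodule name>.<parameter or buffer name>"
model_keys = list(net.state_dict().keys())
print(model_keys)

# Learnable parameters vs. registered buffers
param_names = {name for name, _ in net.named_parameters()}
buffer_names = {name for name, _ in net.named_buffers()}
print("parameters:", sorted(param_names))
print("buffers:", sorted(buffer_names))

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Before any step, the optimizer has no per-parameter state yet
n_state_before = len(optimizer.state_dict()["state"])
print("state entries before step:", n_state_before)

# One dummy training step (random input, sum of outputs as a stand-in loss)
x = torch.randn(4, 3, 32, 32)
loss = net(x).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# With momentum enabled, SGD now stores a momentum buffer for each parameter
state_after = optimizer.state_dict()["state"]
print("state entries after step:", len(state_after))
for key, per_param_state in state_after.items():
    print(key, list(per_param_state.keys()))
```

The ReLU and MaxPool2d layers add no keys, while BatchNorm2d contributes both parameters and buffers. Because momentum is non-zero, SGD keeps one momentum buffer per parameter tensor, so `state` gains one entry per parameter after the first step. The exact layout of these entries may differ between PyTorch versions.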