# 参考书籍：《深度学习框架PyTorch快速开发与实战》 (Reference book: "Deep Learning Framework PyTorch: Rapid Development and Practice")
# 1. 导入常用包 (import commonly used packages)
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
# 2. 设置超参数 (set hyperparameters)
# Training hyperparameters: mini-batch size, number of passes over the
# training set, and the learning rate handed to the Adam optimizer.
BATCH_SIZE, EPOCH, LR = 50, 3, 0.001
# 3. 数据集下载及处理（转换为 PyTorch 可处理的 tensor 格式）(download the dataset and convert it to PyTorch tensors)
# Load the MNIST training and test splits from ./mnist, downloading them on
# first use; ToTensor converts each image to a tensor as it is read.
train_data, test_data = (
    torchvision.datasets.MNIST(
        root='./mnist',
        train=is_train,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
    for is_train in (True, False)
)

# Report the raw image-tensor dimensions of each split.
print(train_data.data.shape)
print(test_data.data.shape)
# DataLoader 将数据集按 batch_size 大小划分为多个批次 (DataLoader splits the dataset into mini-batches of batch_size)
# Serve the training set in shuffled mini-batches of BATCH_SIZE samples.
train_loader = Data.DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)

# Labels for the whole test split.
test_y = test_data.targets
# Whole test split as one (N, 1, 28, 28) batch, divided by 255 with true
# division -- presumably to match ToTensor's scaling of the training images.
test_x = test_data.data.reshape(-1, 1, 28, 28) / 255
# 4. 模型构建 (build the model)
class CNN(nn.Module):
    """Two-stage convolutional classifier that emits 10 scores per image.

    Each stage is Conv2d(kernel 5, stride 1, padding 2) -> ReLU -> MaxPool2d(2).
    The linear head is sized for 32 feature maps of 7x7, i.e. two halvings of
    a single-channel 28x28 input.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 16 channels; padding=2 preserves spatial size and the
        # pool halves it.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Stage 2: 16 -> 32 channels with the same conv/pool geometry.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Fully connected head over the flattened 32*7*7 feature vector.
        self.out = nn.Linear(in_features=32 * 7 * 7, out_features=10)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, 10) for batch ``x``."""
        features = self.conv2(self.conv1(x))
        flat = features.view(len(features), -1)  # keep batch dim, flatten the rest
        return self.out(flat)
# 5. 实例化模型、优化器与损失函数 (instantiate the model, optimizer and loss function)
# Cross-entropy loss over the model's raw class scores and integer labels.
loss_function = nn.CrossEntropyLoss()
# The network itself, and an Adam optimizer over all of its parameters.
cnn = CNN()
optimizer = torch.optim.Adam(params=cnn.parameters(), lr=LR)
# 6. 训练模型并计算评价指标 (train the model and compute the evaluation metric)
# Train for EPOCH passes over train_loader. Every 1000 steps, print the current
# mini-batch loss and the accuracy on the whole test split.
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):
        # Tensors are fed directly: autograd.Variable has been a no-op
        # wrapper since PyTorch 0.4.
        output = cnn(b_x)
        loss = loss_function(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 1000 == 0:
            # Evaluate in eval mode and without autograd, so the forward pass
            # over the full test split does not build and hold a graph.
            cnn.eval()
            with torch.no_grad():
                test_output = cnn(test_x)
            pred_y = test_output.argmax(dim=1)
            # Vectorised accuracy; builtin sum() would loop element by element.
            accuracy = (pred_y == test_y).float().mean().item()
            cnn.train()  # back to training mode for the remaining steps
            print('Epoch:', epoch, '|step:', step,
                  '|train loss:%.4f' % loss.item(),
                  'test accuracy:%.4f' % accuracy)
# 7. 查看前 20 个测试样本的预测结果 (inspect predictions for the first 20 test samples)
# Compare the model's predictions on the first 20 test images with their
# true labels. Inference runs in eval mode and without autograd.
cnn.eval()
with torch.no_grad():
    test_output = cnn(test_x[:20])
pred_y = test_output.argmax(dim=1)  # index of the highest score per image
print(pred_y, 'prediction number')
print(test_y[:20], 'real number')
# 转载请注明原文地址：https://blackberry.8miu.com/read-34854.html (original article URL; cite when reposting)