接上一篇
Pytorch框架下的语义分割实战(网络搭建)
今天我们看一看Z博是怎么训练网络的。。。
附上代码
''' train.py ''' from datetime import datetime import numpy as np import torch import torch.nn as nn import torch.optim as optim from dataset import test_dataloader, train_dataloader from model import FCNs,VGGNet #我将FCN文件名改成了model,所以是from model import... #同样将BagData文件名改成了dataset,所以是from dataset import... device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') #device=torch.device("cpu") #使用cpu #device=torch.device("cuda") #使用GPU def train(epo_num=50, show_vgg_params=False): #vis = visdom.Visdom() #pytorch中的可视化工具 vgg_model = VGGNet(requires_grad=True, show_params=show_vgg_params) fcn_model = FCNs(pretrained_net=vgg_model, n_class=2) #将模型加载到指定设备上 fcn_model = fcn_model.to(device) criterion = nn.BCELoss().to(device) optimizer = optim.SGD(fcn_model.parameters(), lr=1e-2, momentum=0.7) all_train_iter_loss = [] all_test_iter_loss = [] #计算时间 prev_time = datetime.now() for epo in range(epo_num): train_loss = 0 fcn_model.train() for index, (bag, bag_msk) in enumerate(train_dataloader): # bag.shape is torch.Size([4, 3, 160, 160]) # bag_msk.shape is torch.Size([4, 2, 160, 160]) bag = bag.to(device) bag_msk = bag_msk.to(device) optimizer.zero_grad() output = fcn_model(bag) output = torch.sigmoid(output) # output.shape is torch.Size([4, 2, 160, 160]) # print(output) # print(bag_msk) loss = criterion(output, bag_msk) loss.backward() iter_loss = loss.item() all_train_iter_loss.append(iter_loss) train_loss += iter_loss optimizer.step() output_np = output.cpu().detach().numpy().copy() # output_np.shape = (4, 2, 160, 160) output_np = np.argmin(output_np, axis=1) bag_msk_np = bag_msk.cpu().detach().numpy().copy() # bag_msk_np.shape = (4, 2, 160, 160) bag_msk_np = np.argmin(bag_msk_np, axis=1) test_loss = 0 fcn_model.eval() with torch.no_grad(): for index, (bag, bag_msk) in enumerate(test_dataloader): bag = bag.to(device) bag_msk = bag_msk.to(device) optimizer.zero_grad() output = fcn_model(bag) output = torch.sigmoid(output) # output.shape is torch.Size([4, 2, 160, 160]) loss = criterion(output, bag_msk) #预测和原标签图的差 iter_loss = loss.item() #item得到一个元素张量里面的元素值,一般用于返回loss,acc all_test_iter_loss.append(iter_loss) test_loss += iter_loss output_np = output.cpu().detach().numpy().copy() # output_np.shape = (4, 2, 160, 160) output_np = np.argmin(output_np, axis=1) bag_msk_np = bag_msk.cpu().detach().numpy().copy() # bag_msk_np.shape = (4, 2, 160, 160) bag_msk_np = np.argmin(bag_msk_np, axis=1) cur_time = datetime.now() h, remainder = divmod((cur_time - prev_time).seconds, 3600) m, s = divmod(remainder, 60) time_str = "Time %02d:%02d:%02d" % (h, m, s) prev_time = cur_time print('epoch train loss = %f, epoch test loss = %f, %s' %(train_loss/len(train_dataloader), test_loss/len(test_dataloader), time_str)) if np.mod(epo, 5) == 0: torch.save(fcn_model, '/Bags/results/fcn_model_{}.pt'.format(epo)) print('fcn_model_{}.pt'.format(epo)) if __name__ == "__main__": train(epo_num=100, show_vgg_params=False)pytorch中nn.CrossEntropyLoss为交叉熵损失函数,用于解决多分类、二分类问题。 BCELoss是Binary CrossEntropyLoss缩写,nn.BCELoss()是二元交叉熵损失函数,只能解决二分类问题。 某博主说使用此函数时,前面需要加上Sigmoid函数,加上nn.Sigmoid()语句即可。
数学公式为Loss = -w*[p*log(q) + (1-p)*log(1-q)],其中p、q分别为理论标签、实际预测值,w为权重。这里的log对应数学上的ln。
参数说明:torch.nn.BCELoss(weight=None,size_average=True,reduce=True,reduction='mean') ①weight必须和target的shape一致,默认为None。 ②默认情况下reduce=True,size_average=True ③若reduce=False,size_average无效,返回值为向量形式的loss ④若reduce=True,size_average=True,返回loss的均值,即loss.mean() ⑤若reduce=True,size_average=False,返回loss的和,即loss.sum() ⑥若reduction=‘none’,返回向量形式的loss ⑦若reduction=‘sum’,返回loss的和 ⑧若reduction=‘elementwise_mean’,返回loss的平均值 ⑨若reduction=‘mean’,返回loss的平均值
下面用代码展示一下此函数的作用吧...
import torch import torch.nn as nn m = nn.Sigmoid() loss = nn.BCELoss(size_average=False, reduce=False) input = torch.randn(3, requires_grad=True) #empty()是构造一个张量函数,构造了一个1*3的张量,取值范围为[0,2-1],用来初始化target #如果函数有下标'_',表示此函数是Tensor中的内建函数 target = torch.empty(3).random_(2) lossinput = m(input) output = loss(lossinput, target) #print(input) print(' The value of input is', lossinput) print('\n The value of output target is', target) print('\n The value of loss is', output) ''' Out: The value of input is tensor([0.4625, 0.6193, 0.3062], grad_fn=<SigmoidBackward>) The value of output target is tensor([1., 0., 0.]) The value of loss is tensor([0.7712, 0.9658, 0.3656], grad_fn=<BinaryCrossEntropyBackward>) '''SGD的全称是Stochastic Gradient Descent(随机梯度下降)。 梯度下降法主要有三种:批量、随机、小批量 ①批量梯度下降法:整个训练数据集计算梯度 ②随机梯度下降法:随机取一个样本计算梯度 ③小批量梯度下降法:选取少数(即batch_size的大小)样本组成一个小批量样本,用这个小批量样本计算梯度。
fcn_model.parameters():获取fcn_model网络中的参数。搭建好神经网络后,网络的参数都保存在parameters()函数中。
learning rate(lr):学习率较小时,收敛到极值的速度较慢。学习率较大时,容易在收敛过程中发生震荡。
momentum:物理学中指力对时间的累积。每次x的更新量v=-dx*lr,其中dx是目标函数func(x)对x的一阶导。 当本次梯度下降-dx*lr的方向与上次更新量v的方向相同时,上次的更新量能够对本次的搜索起到一个正向加速的作用。 当本次梯度下降-dx*lr的方向与上次更新量v的方向相反时,上次的更新量能够对本次的搜索起到一个减速的作用。
此函数用于一个可遍历的数据对象,如list,tuple等。返回值是数据和对应的下标,一般用在for循环中。 下面看代码吧...
seasons = ['Spring', 'Summer', 'Fall', 'Winter'] list(enumerate(seasons)) ''' Out: [(0, 'Spring'), (1, 'Summer'), (2, 'Fall'), (3, 'Winter')] ''' list(enumerate(seasons, start=1)) ''' Out: [(1, 'Spring'), (2, 'Summer'), (3, 'Fall'), (4, 'Winter')] ''' for s in enumerate(seasons): print(s) print('\n') for i,s in enumerate(seasons): print(s,i) print('\n') for i,s in enumerate(seasons): print(i,s) ''' Out: (0, 'Spring') (1, 'Summer') (2, 'Fall') (3, 'Winter') Spring 0 Summer 1 Fall 2 Winter 3 0 Spring 1 Summer 2 Fall 3 Winter '''.cpu():将数据移至cpu中
.detach():返回一个新的Variable,从当前计算图中分离下来,但是仍指向原变量的存放位置,只是没有梯度,即使后期令requi_grad=True,也不具有梯度。后面进行反向传播时,需要调用detach()的Variable就会停止,不会再继续向前传播。 参考网址:https://blog.csdn.net/weixin_33913332/article/details/93300411
.numpy():将tensor转为numpy数据
.copy():复制
该函数时给出一组数据中最小值的下标。
import numpy as np data = [1,4,2,6,0,0] print(np.argmin(data)) data1 = [[0,2,3],[1,2,3],[4,3,2]] #默认将列表展平,显示最小值的下标 print(np.argmin(data1)) print(np.argmin(data1, axis=1)) data2 = [[[1,2]],[[3,0]],[[4,9]]] print(np.argmin(data2)) ''' Out: 4 0 [0 0 2] 3 '''此函数把除数和余数运算结果结合起来,返回一个包含商和余数的元组。
print('a:', divmod(7,2)) print('b:', divmod(8,2)) ''' Out: a: (3, 1) b: (4, 0) '''好了,到这里就先告一段落了,有不对的地方,欢迎留言指正哦....
下个项目见(*๓´╰╯`๓)