【PyTorch】

    科技2022-07-11  104

    步骤:

    1. 准备数据

    import torch x_data = torch.Tensor([[1.0],[2.0],[3.0]]) y_data = torch.Tensor([[2.0],[4.0],[6.0]])

    2. 设计模型:

    class LinearModel(torch.nn.Module): #继承 def __init__(self): #构造函数,初始化对象 super(LinearModel,self).__init__() #调用父类 self.linear = torch.nn.Linear(1,1) #(权重,偏置) def forward(self,x): y_pred = self.linear(x) #可调用对象 return y_pred model = LinearModel() #实例化

    3. 构造Loss损失和优化

    MSE:

    loss = Σᵢ (ŷᵢ − yᵢ)²   (逐样本平方误差求和;对应下面的 reduction 求和方式)

    SGD(随机梯度下降): w ← w − lr · ∂loss/∂w

    criterion = torch.nn.MSELoss(size_average=False) optimizer = torch.optim.SGD(model.parameters(),lr=0.01) #lr学习率

    4. 训练

    for epoch in range(100): y_pred = model(x_data) loss = criterion(y_pred,y_data) print(epoch,loss) optimizer.zero_grad() #梯度归零 loss.backward() optimizer.step()

    完整实现过程:

    import torch x_data = torch.Tensor([[1.0],[2.0],[3.0]]) y_data = torch.Tensor([[2.0],[4.0],[6.0]]) #准备数据 class LinearModel(torch.nn.Module): #继承 def __init__(self): #构造函数,初始化对象 super(LinearModel,self).__init__() #调用父类 self.linear = torch.nn.Linear(1,1) #(权重,偏置) def forward(self,x): y_pred = self.linear(x) #可调用对象 return y_pred model = LinearModel() #实例化 criterion = torch.nn.MSELoss(size_average=False) optimizer = torch.optim.SGD(model.parameters(),lr=0.01) #lr学习率 for epoch in range(100): y_pred = model(x_data) loss = criterion(y_pred,y_data) print(epoch,loss) optimizer.zero_grad() #梯度归零 loss.backward() optimizer.step() print('w = ',model.linear.weight.item()) print('b = ',model.linear.bias.item()) x_test = torch.Tensor([[4.0]]) y_test = model(x_test) print('y_pred = ',y_test.data)

    结果参考:

    0 tensor(113.5261, grad_fn=<MseLossBackward>) 1 tensor(50.8733, grad_fn=<MseLossBackward>) 2 tensor(22.9772, grad_fn=<MseLossBackward>) 3 tensor(10.5539, grad_fn=<MseLossBackward>) 4 tensor(5.0188, grad_fn=<MseLossBackward>) 5 tensor(2.5501, grad_fn=<MseLossBackward>) 6 tensor(1.4465, grad_fn=<MseLossBackward>) 7 tensor(0.9508, grad_fn=<MseLossBackward>) 8 tensor(0.7257, grad_fn=<MseLossBackward>) 9 tensor(0.6211, grad_fn=<MseLossBackward>) 10 tensor(0.5703, grad_fn=<MseLossBackward>) 11 tensor(0.5435, grad_fn=<MseLossBackward>) 12 tensor(0.5273, grad_fn=<MseLossBackward>) 13 tensor(0.5161, grad_fn=<MseLossBackward>) 14 tensor(0.5070, grad_fn=<MseLossBackward>) 15 tensor(0.4990, grad_fn=<MseLossBackward>) 16 tensor(0.4915, grad_fn=<MseLossBackward>) 17 tensor(0.4843, grad_fn=<MseLossBackward>) 18 tensor(0.4773, grad_fn=<MseLossBackward>) 19 tensor(0.4704, grad_fn=<MseLossBackward>) 20 tensor(0.4636, grad_fn=<MseLossBackward>) 21 tensor(0.4569, grad_fn=<MseLossBackward>) 22 tensor(0.4504, grad_fn=<MseLossBackward>) 23 tensor(0.4439, grad_fn=<MseLossBackward>) 24 tensor(0.4375, grad_fn=<MseLossBackward>) 25 tensor(0.4312, grad_fn=<MseLossBackward>) 26 tensor(0.4250, grad_fn=<MseLossBackward>) 27 tensor(0.4189, grad_fn=<MseLossBackward>) 28 tensor(0.4129, grad_fn=<MseLossBackward>) 29 tensor(0.4070, grad_fn=<MseLossBackward>) 30 tensor(0.4011, grad_fn=<MseLossBackward>) 31 tensor(0.3953, grad_fn=<MseLossBackward>) 32 tensor(0.3897, grad_fn=<MseLossBackward>) 33 tensor(0.3841, grad_fn=<MseLossBackward>) 34 tensor(0.3785, grad_fn=<MseLossBackward>) 35 tensor(0.3731, grad_fn=<MseLossBackward>) 36 tensor(0.3677, grad_fn=<MseLossBackward>) 37 tensor(0.3625, grad_fn=<MseLossBackward>) 38 tensor(0.3572, grad_fn=<MseLossBackward>) 39 tensor(0.3521, grad_fn=<MseLossBackward>) 40 tensor(0.3471, grad_fn=<MseLossBackward>) 41 tensor(0.3421, grad_fn=<MseLossBackward>) 42 tensor(0.3371, grad_fn=<MseLossBackward>) 43 tensor(0.3323, grad_fn=<MseLossBackward>) 44 tensor(0.3275, 
grad_fn=<MseLossBackward>) 45 tensor(0.3228, grad_fn=<MseLossBackward>) 46 tensor(0.3182, grad_fn=<MseLossBackward>) 47 tensor(0.3136, grad_fn=<MseLossBackward>) 48 tensor(0.3091, grad_fn=<MseLossBackward>) 49 tensor(0.3047, grad_fn=<MseLossBackward>) 50 tensor(0.3003, grad_fn=<MseLossBackward>) 51 tensor(0.2960, grad_fn=<MseLossBackward>) 52 tensor(0.2917, grad_fn=<MseLossBackward>) 53 tensor(0.2875, grad_fn=<MseLossBackward>) 54 tensor(0.2834, grad_fn=<MseLossBackward>) 55 tensor(0.2793, grad_fn=<MseLossBackward>) 56 tensor(0.2753, grad_fn=<MseLossBackward>) 57 tensor(0.2713, grad_fn=<MseLossBackward>) 58 tensor(0.2674, grad_fn=<MseLossBackward>) 59 tensor(0.2636, grad_fn=<MseLossBackward>) 60 tensor(0.2598, grad_fn=<MseLossBackward>) 61 tensor(0.2561, grad_fn=<MseLossBackward>) 62 tensor(0.2524, grad_fn=<MseLossBackward>) 63 tensor(0.2488, grad_fn=<MseLossBackward>) 64 tensor(0.2452, grad_fn=<MseLossBackward>) 65 tensor(0.2417, grad_fn=<MseLossBackward>) 66 tensor(0.2382, grad_fn=<MseLossBackward>) 67 tensor(0.2348, grad_fn=<MseLossBackward>) 68 tensor(0.2314, grad_fn=<MseLossBackward>) 69 tensor(0.2281, grad_fn=<MseLossBackward>) 70 tensor(0.2248, grad_fn=<MseLossBackward>) 71 tensor(0.2216, grad_fn=<MseLossBackward>) 72 tensor(0.2184, grad_fn=<MseLossBackward>) 73 tensor(0.2152, grad_fn=<MseLossBackward>) 74 tensor(0.2122, grad_fn=<MseLossBackward>) 75 tensor(0.2091, grad_fn=<MseLossBackward>) 76 tensor(0.2061, grad_fn=<MseLossBackward>) 77 tensor(0.2031, grad_fn=<MseLossBackward>) 78 tensor(0.2002, grad_fn=<MseLossBackward>) 79 tensor(0.1973, grad_fn=<MseLossBackward>) 80 tensor(0.1945, grad_fn=<MseLossBackward>) 81 tensor(0.1917, grad_fn=<MseLossBackward>) 82 tensor(0.1890, grad_fn=<MseLossBackward>) 83 tensor(0.1862, grad_fn=<MseLossBackward>) 84 tensor(0.1836, grad_fn=<MseLossBackward>) 85 tensor(0.1809, grad_fn=<MseLossBackward>) 86 tensor(0.1783, grad_fn=<MseLossBackward>) 87 tensor(0.1758, grad_fn=<MseLossBackward>) 88 tensor(0.1732, 
grad_fn=<MseLossBackward>) 89 tensor(0.1707, grad_fn=<MseLossBackward>) 90 tensor(0.1683, grad_fn=<MseLossBackward>) 91 tensor(0.1659, grad_fn=<MseLossBackward>) 92 tensor(0.1635, grad_fn=<MseLossBackward>) 93 tensor(0.1611, grad_fn=<MseLossBackward>) 94 tensor(0.1588, grad_fn=<MseLossBackward>) 95 tensor(0.1565, grad_fn=<MseLossBackward>) 96 tensor(0.1543, grad_fn=<MseLossBackward>) 97 tensor(0.1521, grad_fn=<MseLossBackward>) 98 tensor(0.1499, grad_fn=<MseLossBackward>) 99 tensor(0.1477, grad_fn=<MseLossBackward>) w = 1.7441235780715942 b = 0.5816670656204224 y_pred = tensor([[7.5582]])

     

    Processed: 0.058, SQL: 8