Deep Learning 2.0-5. TensorFlow Basic Operations: Forward Propagation (Tensors) in Practice

Forward Propagation (Tensors) in Practice
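
This example trains a three-layer fully connected network on MNIST with nothing but tensor operations: no keras.Model, no optimizer object. The forward pass is out = relu(relu(x @ w1 + b1) @ w2 + b2) @ w3 + b3, taking each batch through the shapes [b, 784] => [b, 256] => [b, 128] => [b, 10], and the parameters are updated by hand from the gradients recorded by tf.GradientTape.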

    # Silence irrelevant TF log messages.
    import os
    # Valid values are 0, 1, 2, 3, corresponding to INFO, WARNING, ERROR, FATAL.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import datasets

    # Load the dataset. x: [60k, 28, 28], y: [60k]
    (x, y), _ = datasets.mnist.load_data()
    # x: [0~255] => [0~1.]
    x = tf.convert_to_tensor(x, dtype=tf.float32) / 255
    y = tf.convert_to_tensor(y, dtype=tf.int32)
    print(x.shape, y.shape, x.dtype, y.dtype)

    # Check the value ranges.
    print(tf.reduce_min(x), tf.reduce_max(x))
    print(tf.reduce_min(y), tf.reduce_max(y))

    # from_tensor_slices slices the tensors along their first dimension
    # and builds a dataset from the slices.
    train_db = tf.data.Dataset.from_tensor_slices((x, y)).batch(128)
    train_iter = iter(train_db)
    sample = next(train_iter)
    print('batch:', sample[0].shape, sample[1].shape)
    # batch: (128, 28, 28) (128,)

    # Build the weights.
    # X: [b, 784] => [b, 256] => [b, 128] => [b, 10]
    # W: [dim_in, dim_out], b: [dim_out]
    # truncated_normal samples from a normal distribution (here mean 0,
    # stddev 0.1) with values beyond two standard deviations re-drawn,
    # which helps against exploding and vanishing gradients.
    w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1))
    b1 = tf.Variable(tf.zeros([256]))
    w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
    b2 = tf.Variable(tf.zeros([128]))
    w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
    b3 = tf.Variable(tf.zeros([10]))

    lr = 1e-3

    for epoch in range(10):  # iterate over the whole dataset
        for step, (x, y) in enumerate(train_db):  # for every batch
            # x: [128, 28, 28], y: [128]
            # [b, 28, 28] => [b, 28*28]
            x = tf.reshape(x, [-1, 28*28])

            # GradientTape only tracks tf.Variable objects by default.
            with tf.GradientTape() as tape:
                # x: [b, 28*28]
                # h1 = x@w1 + b1
                # [b, 784]@[784, 256] + [256] => [b, 256] + [256] => [b, 256] + [b, 256]
                h1 = x @ w1 + tf.broadcast_to(b1, [x.shape[0], 256])
                h1 = tf.nn.relu(h1)
                # [b, 256] => [b, 128]
                h2 = h1 @ w2 + b2
                h2 = tf.nn.relu(h2)
                # [b, 128] => [b, 10]
                out = h2 @ w3 + b3

                # Compute the loss.
                # out: [b, 10]
                # y: [b] => [b, 10]
                y_onehot = tf.one_hot(y, depth=10)
                # mse = mean((y - out)^2)
                # [b, 10]
                loss = tf.square(y_onehot - out)
                # mean: scalar
                loss = tf.reduce_mean(loss)

            # Compute the gradients; returns one gradient per listed variable.
            grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
            # print(grads)
            # Update the trainable parameters.
            # assign_sub updates in place, keeping the tf.Variable type.
            w1.assign_sub(lr * grads[0])
            b1.assign_sub(lr * grads[1])
            w2.assign_sub(lr * grads[2])
            b2.assign_sub(lr * grads[3])
            w3.assign_sub(lr * grads[4])
            b3.assign_sub(lr * grads[5])

            if step % 100 == 0:
                print(epoch, step, 'loss', float(loss))
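
The script loads the MNIST test split only to discard it (the _ returned by load_data()). As a minimal evaluation sketch, assuming the trained w1, b1, w2, b2, w3, b3 are still in scope after the loop above (this block is not part of the original script), held-out accuracy can be checked with the same forward pass:

    # Evaluation sketch: assumes w1..b3 from the training script are in scope.
    (_, _), (x_test, y_test) = datasets.mnist.load_data()
    x_test = tf.convert_to_tensor(x_test, dtype=tf.float32) / 255
    x_test = tf.reshape(x_test, [-1, 28*28])      # [10k, 784]
    y_test = tf.convert_to_tensor(y_test, dtype=tf.int64)

    # Same forward pass as in training; no tape needed for inference.
    h1 = tf.nn.relu(x_test @ w1 + b1)             # [10k, 256]
    h2 = tf.nn.relu(h1 @ w2 + b2)                 # [10k, 128]
    out = h2 @ w3 + b3                            # [10k, 10]

    pred = tf.argmax(out, axis=1)                 # [10k], int64
    correct = tf.cast(tf.equal(pred, y_test), tf.float32)
    print('test acc:', float(tf.reduce_mean(correct)))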

Output of the training script:

    (60000, 28, 28) (60000,) <dtype: 'float32'> <dtype: 'int32'>
    tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(1.0, shape=(), dtype=float32)
    tf.Tensor(0, shape=(), dtype=int32) tf.Tensor(9, shape=(), dtype=int32)
    batch: (128, 28, 28) (128,)
    0 0 loss 0.3418075442314148
    0 100 loss 0.20874373614788055
    0 200 loss 0.19508568942546844
    0 300 loss 0.17686240375041962
    0 400 loss 0.17524364590644836
    1 0 loss 0.1492575705051422
    1 100 loss 0.15576913952827454
    1 200 loss 0.15462367236614227
    1 300 loss 0.14648546278476715
    1 400 loss 0.14724934101104736
    2 0 loss 0.12436133623123169
    2 100 loss 0.13543009757995605
    2 200 loss 0.13396985828876495
    2 300 loss 0.1286272406578064
    2 400 loss 0.13028298318386078
    3 0 loss 0.10929928719997406
    3 100 loss 0.12242629379034042
    3 200 loss 0.12099750339984894
    3 300 loss 0.11676524579524994
    3 400 loss 0.11878874152898788
    4 0 loss 0.09928615391254425
    4 100 loss 0.11343791335821152
    4 200 loss 0.11198092997074127
    4 300 loss 0.10829611867666245
    4 400 loss 0.11033062636852264
    5 0 loss 0.09216360747814178
    5 100 loss 0.1067122370004654
    5 200 loss 0.10542497783899307
    5 300 loss 0.10182454437017441
    5 400 loss 0.1037207618355751
    6 0 loss 0.08672191202640533
    6 100 loss 0.10142513364553452
    6 200 loss 0.10031310468912125
    6 300 loss 0.09669476002454758
    6 400 loss 0.09842309355735779
    7 0 loss 0.0824439525604248
    7 100 loss 0.09712380170822144
    7 200 loss 0.09613628685474396
    7 300 loss 0.09249486774206161
    7 400 loss 0.09409850835800171
    8 0 loss 0.07895340770483017
    8 100 loss 0.09356103837490082
    8 200 loss 0.09263540804386139
    8 300 loss 0.08899327367544174
    8 400 loss 0.09041593968868256
    9 0 loss 0.07604382187128067
    9 100 loss 0.09052439779043198
    9 200 loss 0.08964134752750397
    9 300 loss 0.08601444214582443
    9 400 loss 0.08727019280195236
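
The loss settles around 0.087 after 10 epochs, and part of the reason it plateaus is the loss itself: MSE against one-hot targets is a weak training signal for classification. A common variation, sketched below under the assumption that everything else in the script stays unchanged (this block is not part of the original post), is to replace the MSE block inside the GradientTape with softmax cross-entropy on the raw logits:

    # Variation sketch: cross-entropy instead of MSE, inside the same tape.
    with tf.GradientTape() as tape:
        h1 = tf.nn.relu(x @ w1 + b1)
        h2 = tf.nn.relu(h1 @ w2 + b2)
        out = h2 @ w3 + b3                        # raw logits, no softmax
        y_onehot = tf.one_hot(y, depth=10)
        # from_logits=True applies a numerically stable softmax internally.
        loss = tf.reduce_mean(
            tf.losses.categorical_crossentropy(y_onehot, out, from_logits=True))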