TensorFlow: Defining, Training, and Saving a Model



This example shows how to define, train, and save a model with TensorFlow. For details, see the official TensorFlow example.

First, generate synthetic training data scattered around a known line y = 3x + 2:

    import tensorflow as tf
    import matplotlib.pyplot as plt

    tf.__version__   # '2.1.0'

    # The actual line
    TRUE_W = 3.0
    TRUE_B = 2.0

    NUM_EXAMPLES = 1000

    # A vector of random x values
    x = tf.random.normal(shape=[NUM_EXAMPLES])

    # Generate some noise
    noise = tf.random.normal(shape=[NUM_EXAMPLES])

    # Calculate y
    y = x * TRUE_W + TRUE_B + noise

    plt.scatter(x, y, c="b")

Next, define a simple linear model as a tf.Module with two trainable variables, plus a mean-squared-error loss:

    class MyModel(tf.Module):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.w = tf.Variable(5.0)
            self.b = tf.Variable(0.0)

        def __call__(self, x):
            return self.w * x + self.b

    model = MyModel()
    print("Variables:", model.variables)
    # Variables: (<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>,
    #             <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)

    model(3.0).numpy()   # 15.0

    # Mean squared error between targets and predictions
    def loss(target_y, predicted_y):
        return tf.reduce_mean(tf.square(target_y - predicted_y))

    plt.scatter(x, y, c='b')
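As a quick sanity check (a minimal sketch, not from the original post), the loss function can be verified against a hand-computed value:

    # Hypothetical check: MSE of predictions [1.0, 2.0] against targets [0.0, 2.0]
    # is ((0 - 1)^2 + (2 - 2)^2) / 2 = 0.5
    loss(tf.constant([0.0, 2.0]), tf.constant([1.0, 2.0])).numpy()   # 0.5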

Plotting the untrained model's predictions (red) against the data (blue) shows how far off it starts:

    plt.scatter(x, y)
    plt.scatter(x, model(x), c='r')

    print("Current loss: %1.6f" % loss(model(x), y).numpy()) Current loss: 9.385428 def train(model,x,y,learning_rate): with tf.GradientTape() as t: current_loss = loss(y, model(x)) dw,db=t.gradient(current_loss,[model.w,model.b]) model.w.assign_sub(learning_rate*dw) model.b.assign_sub(learning_rate*db) model=MyModel() Ws,bs,losses=[],[],[] epochs=range(10) def trainging_loop(model,x,y): for epoch in epochs: train(model,x,y,0.1) Ws.append(model.w.numpy()) bs.append(model.b.numpy()) current_loss=loss(y,model(x)).numpy() losses.append(current_loss) print("Epoch %2d, W: %1f,b: %1.2f, current loss: %2.2f" % (epoch,Ws[-1],bs[-1],current_loss)) # print("Epoch {}, w: {}, b{},current loss :{}".format(epoch,Ws[-1],bs[-1],current_loss)) trainging_loop(model,x,y) Epoch 0, W: 4.576263,b: 0.42, current loss: 6.23 Epoch 1, W: 4.243350,b: 0.75, current loss: 4.27 Epoch 2, W: 3.981811,b: 1.01, current loss: 3.05 Epoch 3, W: 3.776360,b: 1.21, current loss: 2.29 Epoch 4, W: 3.614980,b: 1.37, current loss: 1.83 Epoch 5, W: 3.488228,b: 1.50, current loss: 1.53 Epoch 6, W: 3.388681,b: 1.61, current loss: 1.35 Epoch 7, W: 3.310508,b: 1.69, current loss: 1.24 Epoch 8, W: 3.249123,b: 1.75, current loss: 1.17 Epoch 9, W: 3.200926,b: 1.80, current loss: 1.13 plt.scatter(x,y,c='b') plt.scatter(x,model(x),c='r') <matplotlib.collections.PathCollection at 0x23c95e00f48>

The evolution of w and b across the epochs can be plotted directly:

    plt.plot(Ws, label='w')
    plt.plot(bs, label='b')
    plt.legend()
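The losses collected by the training loop are never plotted in the original post; as a minimal sketch, the loss curve can be drawn the same way:

    # Sketch (not in the original post): visualize per-epoch loss to see convergence
    plt.plot(losses, label='loss')
    plt.xlabel('epoch')
    plt.legend()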

The same model can also be written as a tf.keras.Model, which adds Keras utilities such as weight saving and the built-in compile/fit training loop. Note that a subclassed Keras model implements call() rather than overriding __call__ directly:

    class MyModelKeras(tf.keras.Model):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.w = tf.Variable(5.0)
            self.b = tf.Variable(0.0)

        # Keras invokes call() through Model.__call__
        def call(self, x, **kwargs):
            return self.w * x + self.b

    keras_model = MyModelKeras()

    # Reuse the manual training loop, then save the weights to a checkpoint
    training_loop(keras_model, x, y)
    keras_model.save_weights("my_checkpoint")

    Epoch  0, W: 4.576263, b: 0.42, current loss: 6.23
    Epoch  1, W: 4.243350, b: 0.75, current loss: 4.27
    Epoch  2, W: 3.981811, b: 1.01, current loss: 3.05
    Epoch  3, W: 3.776360, b: 1.21, current loss: 2.29
    Epoch  4, W: 3.614980, b: 1.37, current loss: 1.83
    Epoch  5, W: 3.488228, b: 1.50, current loss: 1.53
    Epoch  6, W: 3.388681, b: 1.61, current loss: 1.35
    Epoch  7, W: 3.310508, b: 1.69, current loss: 1.24
    Epoch  8, W: 3.249123, b: 1.75, current loss: 1.17
    Epoch  9, W: 3.200926, b: 1.80, current loss: 1.13

Instead of the manual loop, the model can be compiled and trained with Keras's built-in fit():

    keras_model = MyModelKeras()

    keras_model.compile(
        run_eagerly=False,
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
        loss=tf.keras.losses.mean_squared_error,
    )

    keras_model.fit(x, y, epochs=10, batch_size=1000)

    Train on 1000 samples
    Epoch 1/10
    1000/1000 [==============================] - 0s 183us/sample - loss: 9.3854
    Epoch 2/10
    1000/1000 [==============================] - 0s 1us/sample - loss: 6.2282
    Epoch 3/10
    1000/1000 [==============================] - 0s 1us/sample - loss: 4.2678
    Epoch 4/10
    1000/1000 [==============================] - 0s 1us/sample - loss: 3.0505
    Epoch 5/10
    1000/1000 [==============================] - 0s 1us/sample - loss: 2.2946
    Epoch 6/10
    1000/1000 [==============================] - 0s 1us/sample - loss: 1.8252
    Epoch 7/10
    1000/1000 [==============================] - 0s 1us/sample - loss: 1.5336
    Epoch 8/10
    1000/1000 [==============================] - 0s 2us/sample - loss: 1.3526
    Epoch 9/10
    1000/1000 [==============================] - 0s 1us/sample - loss: 1.2401
    Epoch 10/10
    1000/1000 [==============================] - 0s 1us/sample - loss: 1.1703
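The post saves weights but does not show restoring them. As a minimal sketch (reusing the "my_checkpoint" name written by save_weights above), a fresh model instance can reload the trained values with load_weights:

    # Sketch: restore the trained w and b into a new model instance.
    # Assumes the "my_checkpoint" checkpoint files created above.
    restored = MyModelKeras()
    restored.load_weights("my_checkpoint")
    print(restored.w.numpy(), restored.b.numpy())   # trained values, not 5.0 / 0.0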