tensorflow之tf.where()函数

    科技2025-04-24  16

    tf.where()函数的语法格式如下:

    import tensorflow as tf tf.where( condition, x=None, y=None, name=None )

    作用:该函数的作用是根据condition,返回相对应的x或y,返回值是一个tf.bool类型的Tensor。

    例1:

    import tensorflow as tf sess=tf.Session() A =tf.where(False,123,321) >>> print(A) Tensor("Select:0", shape=(), dtype=int32) >>> print(sess.run(A)) 321 >>> B=tf.where(True,123,321) >>> print(sess.run(B)) 123 sess.close()

    例2:

    sess=tf.Session() >>> X = [["China","Henan","Changsha"], ... ["You","Love","China"]] >>> Y = [["America","Shanxi","lvliang"], ... ["I","Like","Country"]] print(sess.run(tf.where(condition_1,X,Y))) [[b'China' b'Shanxi' b'lvliang'] [b'I' b'Love' b'China']] sess.close()

    由以上两个例子我们可以清楚地看到,tf.where()的作用就是根据condition返回相对应的X 或 Y值。若condition=True,则返回对应X的值,False则返回对应的Y值。

    以上内容,如有错误,敬请批评指正!谢谢!

    Processed: 0.011, SQL: 8