tf.where(cond)：返回 cond 中值为 True 的元素的坐标；tf.where(cond, A, B)：按 cond 逐元素选择，cond 为 True 的位置取 A 中对应的值，否则取 B 中对应的值。
例子（注：下面的代码演示的是用 tf.meshgrid 生成二维网格并绘制函数 z = sin(x) + sin(y) 的取值图和等高线图，并未用到 tf.where）
import math
import os

# Must be set before TensorFlow is imported, otherwise it has no effect;
# reduces TensorFlow's C++ log verbosity.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
import matplotlib.pyplot as plt


def fun(x):
    """Evaluate z = sin(x[..., 0]) + sin(x[..., 1]).

    :param x: tensor whose last axis has size 2 (the x and y coordinates of
        each point); any number of leading axes is allowed, e.g. [b, 2] or
        [H, W, 2].
    :return: tensor of shape x.shape[:-1] with the function value per point.
    """
    z = tf.math.sin(x[..., 0]) + tf.math.sin(x[..., 1])
    return z


if __name__ == '__main__':
    # Sample one full period [0, 2*pi] on each axis. math.pi replaces the
    # former 3.14 approximation so the grid really ends at 2*pi.
    upper = 2 * math.pi
    x = tf.linspace(0., upper, 500)
    y = tf.linspace(0., upper, 500)
    # point_x, point_y: [500, 500]
    point_x, point_y = tf.meshgrid(x, y)
    # points: [500, 500, 2]; each entry is an (x, y) pair.
    points = tf.stack([point_x, point_y], axis=2)
    print('points:', points.shape)
    z = fun(points)
    print('z:', z.shape)

    plt.figure('plot 2d func value')
    # extent maps pixel indices onto the same (x, y) coordinates used by the
    # contour plot below, so both figures share axis ranges.
    plt.imshow(z, origin='lower', interpolation='none',
               extent=[0, upper, 0, upper])
    plt.colorbar()

    plt.figure('plot 2d func contour')
    # Draw the contour lines.
    plt.contour(point_x, point_y, z)
    plt.colorbar()
    plt.show()

# Expected console output:
# points: (500, 500, 2)
# z: (500, 500)
