深度学习2.08.tensorflow的高阶操作之张量排序

    科技2025-07-18  7

    文章目录

    张量的排序1.tf.sort-排序/tf.argsort-排序并返回索引2.tf.math.top_k-最大值的前几个3.Top-k-Accuracy-预测准确度

    张量的排序

    1.tf.sort-排序/tf.argsort-排序并返回索引

    2.tf.math.top_k-最大值的前几个

    3.Top-k-Accuracy-预测准确度

    # 将无关信息屏蔽掉 import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf tf.random.set_seed(2467) # output->[b,n] target->[b,] def accuracy(output,target,topk=(1,)): maxk = max(topk) batch_size = target.shape[0] # 返回最大值前maxk个的索引 pred = tf.math.top_k(output,maxk).indices # 转置 pred = tf.transpose(pred,perm=[1,0]) # 将target广播成pred形状 target_ = tf.broadcast_to(target,pred.shape) # 比较 correct = tf.equal(pred,target_) res = [] for k in topk: correct_k = tf.cast(tf.reshape(correct[:k],[-1]),dtype=tf.float32) # print('123=',correct_k) correct_k = tf.reduce_sum(correct_k) acc = float(correct_k / batch_size) res.append(acc) return res if __name__ == '__main__': # 正态分布 output = tf.random.normal([10,6]) # 使6类概率总和为1 output = tf.math.softmax(output,axis=1) # 均匀分布 target = tf.random.uniform([10],maxval=6,dtype=tf.int32) print('prob:',output.numpy()) pred = tf.argmax(output,axis=1) print('pred:',pred.numpy()) print('label:',target.numpy()) acc = accuracy(output,target,topk=(1,2,3,4,5,6)) print('top-1-6 acc:',acc) prob: [[0.25310278 0.21715644 0.16043882 0.13088997 0.04334083 0.19507109] [0.05892418 0.04548917 0.00926314 0.14529602 0.66777605 0.07325139] [0.09742808 0.08304427 0.07460099 0.04067177 0.626185 0.07806987] [0.20478569 0.12294924 0.12010485 0.13751231 0.36418733 0.05046057] [0.11872064 0.31072393 0.12530336 0.1552888 0.2132587 0.07670452] [0.01519807 0.09672114 0.1460476 0.00934331 0.5649092 0.16778067] [0.04199061 0.18141054 0.06647632 0.6006175 0.03198383 0.07752118] [0.09226219 0.2346089 0.13022321 0.16295874 0.05362028 0.3263266 ] [0.07019574 0.0861177 0.10912605 0.10521299 0.2152082 0.4141393 ] [0.01882887 0.26597694 0.19122466 0.24109262 0.14920162 0.13367532]] pred: [0 4 4 4 1 4 3 5 5 1] label: [0 2 3 4 2 4 2 3 5 5] top-1-6 acc: [0.4000000059604645, 0.4000000059604645, 0.5, 0.699999988079071, 0.800000011920929, 1.0]
    Processed: 0.016, SQL: 8