11 tensorflow 张量排序
程序员文章站
2024-03-11 17:07:07
...
sort
argsort
top_k
# 张量排序
import tensorflow as tf
a = tf.random.shuffle(tf.range(5))
print(a.numpy())
# out:[3 4 0 1 2]
# 升序排列
print(tf.sort(a).numpy())
# out:[0 1 2 3 4]
# 降序排列
print(tf.sort(a, direction="DESCENDING").numpy())
# out:[4 3 2 1 0]
# 根据排序后的序列找到排序前元素的索引号
idx = tf.argsort(a, direction="DESCENDING")
print(idx)
# out:tf.Tensor([1 0 4 3 2], shape=(5,), dtype=int32)
# 利用索引号还原序列
print(tf.gather(a, idx))
# out:tf.Tensor([4 3 2 1 0], shape=(5,), dtype=int32)
# Top_k:返回最大的若干个值, 常用于Top_k Accuracy(预测值较大的多个可能)
res = tf.math.top_k(a, 2)
print(res.values) # 值
# out:tf.Tensor([4 3], shape=(2,), dtype=int32)
print(res.indices) # 索引
# out:tf.Tensor([1 0], shape=(2,), dtype=int32)