欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

11 tensorflow 张量排序

程序员文章站 2024-03-11 17:07:07
...

张量排序

sort

argsort

top_k

# 张量排序
import tensorflow as tf

a = tf.random.shuffle(tf.range(5))
print(a.numpy())
# out:[3 4 0 1 2]

# 升序排列
print(tf.sort(a).numpy())
# out:[0 1 2 3 4]
# 降序排列
print(tf.sort(a, direction="DESCENDING").numpy())
# out:[4 3 2 1 0]

# 根据排序后的序列找到排序前元素的索引号
idx = tf.argsort(a, direction="DESCENDING")
print(idx)
# out:tf.Tensor([1 0 4 3 2], shape=(5,), dtype=int32)

# 利用索引号还原序列
print(tf.gather(a, idx))
# out:tf.Tensor([4 3 2 1 0], shape=(5,), dtype=int32)

# Top_k:返回最大的若干个值, 常用于Top_k Accuracy(预测值较大的多个可能)
res = tf.math.top_k(a, 2)
print(res.values)  # 值
# out:tf.Tensor([4 3], shape=(2,), dtype=int32)
print(res.indices)  # 索引
# out:tf.Tensor([1 0], shape=(2,), dtype=int32)