欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

10 tensorflow tf.gather 和 tf.gather_nd

程序员文章站 2024-03-11 18:08:19
...

tf.gather和tf.gather_nd


tf.gather和tf.gather_nd理解起来比较麻烦,这里以一个具体场景来说明。
学生成绩:shape=[4,35,8]
具体意义:四个班 0,1,2,3 , 每班35个学生序号0-34,8门课程的成绩课程号0-7 (为了容易表示都从0开始)。

tf.gather

gather就是收集的意思,最简单的用法:

a = tf.random.normal([4, 35, 8])
print(tf.gather(a, axis=0, indices=[2, 3]).shape)
# out:(2, 35, 8)

输出2和3班的成绩。类似于:

print(a[2:4].shape)

区别在于其可以按indices给定的顺序收集数据,而不是只能按索引顺序获取数据。

print(tf.gather(a, axis=0, indices=[2, 1, 3, 0]).shape)  # 按indices收集数据
# out:(4, 35, 8)

成绩的顺序是:2班,1班,3班,0班。
同样:

print(tf.gather(a, axis=1, indices=[2, 3, 7, 9, 16]).shape)
# out:(4, 5, 8)

四个班里2,3,7,9,16学生的成绩。
tf.gather的功能:收集数据和变换顺序

tf.gather_nd

如何获得0班序号为0同学,1班序号为1同学,2班序号为2的数据,3班序号为3的成绩?这就需要用到 tf.gather_nd。
先简单的了解一下。

print(tf.gather_nd(a, [0]).shape)  # 0班的成绩
print(tf.gather_nd(a, [0, 1]).shape)  # 0班1号同学的成绩
print(tf.gather_nd(a, [0, 1, 2]).shape)  # 0班1号同学课程号为2的成绩
print(tf.gather_nd(a, [[0, 1, 2]]).shape)  # 0班1号同学课程号为2的成绩,并放在一个列表中。

输出:

(35, 8) 
(8,)
()
(1,)

具体问题:获得0班序号为0同学,1班序号为1同学,2班序号为2的数据,3班序号为3的成绩。

print(tf.gather_nd(a, [[0, 0], [1, 1], [2, 2], [3, 3]]).shape)
# out:(4, 8)

再加上课程:

print(tf.gather_nd(a, [[0, 0, 0], [1, 2, 3], [2, 2, 2], [3, 4, 4]]).shape)
# out:(4,)

表示:0班序号为0同学的课程号为0的成绩,1班序号为2同学的课程号为3的成绩,2班序号为2同学的课程号为2的成绩,3班序号为4同学的课程号为4的成绩。
同样可以将收集到的数据放入一个张量中。

print(tf.gather_nd(a, [[[0, 0, 0], [1, 2, 3], [2, 2, 2], [3, 4, 4]]]).shape)
# out:(1, 4)