10 tensorflow tf.gather 和 tf.gather_nd
程序员文章站
2024-03-11 18:08:19
...
tf.gather和tf.gather_nd
tf.gather和tf.gather_nd理解起来比较麻烦,这里以一个具体场景来说明。
学生成绩:shape=[4,35,8]
具体意义:四个班 0,1,2,3 , 每班35个学生序号0-34,8门课程的成绩课程号0-7 (为了容易表示都从0开始)。
tf.gather
gather就是收集的意思,最简单的用法:
a = tf.random.normal([4, 35, 8])
print(tf.gather(a, axis=0, indices=[2, 3]).shape)
# out:(2, 35, 8)
输出2和3班的成绩。类似于:
print(a[2:4].shape)
区别在于其可以按indices给定的顺序收集数据,而不是只能按索引顺序获取数据。
print(tf.gather(a, axis=0, indices=[2, 1, 3, 0]).shape) # 按indices收集数据
# out:(4, 35, 8)
成绩的顺序是:2班,1班,3班,0班。
同样:
print(tf.gather(a, axis=1, indices=[2, 3, 7, 9, 16]).shape)
# out:(4, 5, 8)
四个班里2,3,7,9,16学生的成绩。
tf.gather的功能:收集数据和变换顺序
tf.gather_nd
如何获得0班序号为0同学,1班序号为1同学,2班序号为2的数据,3班序号为3的成绩?这就需要用到 tf.gather_nd。
先简单的了解一下。
print(tf.gather_nd(a, [0]).shape) # 0班的成绩
print(tf.gather_nd(a, [0, 1]).shape) # 0班1号同学的成绩
print(tf.gather_nd(a, [0, 1, 2]).shape) # 0班1号同学课程号为2的成绩
print(tf.gather_nd(a, [[0, 1, 2]]).shape) # 0班1号同学课程号为2的成绩,并放在一个列表中。
输出:
(35, 8)
(8,)
()
(1,)
具体问题:获得0班序号为0同学,1班序号为1同学,2班序号为2的数据,3班序号为3的成绩。
print(tf.gather_nd(a, [[0, 0], [1, 1], [2, 2], [3, 3]]).shape)
# out:(4, 8)
再加上课程:
print(tf.gather_nd(a, [[0, 0, 0], [1, 2, 3], [2, 2, 2], [3, 4, 4]]).shape)
# out:(4,)
表示:0班序号为0同学的课程号为0的成绩,1班序号为2同学的课程号为3的成绩,2班序号为2同学的课程号为2的成绩,3班序号为4同学的课程号为4的成绩。
同样可以将收集到的数据放入一个张量中。
print(tf.gather_nd(a, [[[0, 0, 0], [1, 2, 3], [2, 2, 2], [3, 4, 4]]]).shape)
# out:(1, 4)
推荐阅读
-
10 tensorflow tf.gather 和 tf.gather_nd
-
win10下python3.5.2和tensorflow安装环境搭建教程
-
win10下tensorflow和matplotlib安装教程
-
TensorFlow在Win10上的安装注意和步骤
-
tensorflow中tf.slice和tf.gather切片函数的使用
-
tensorflow中tf.slice和tf.gather切片函数的使用
-
Win10 安装Anaconda、Pycharm、Tensorflow和Pytorch
-
〖 tensorflow2.0笔记10〗:全连接层和输出方式!
-
win10下python3.5.2和tensorflow安装环境搭建教程
-
win10下tensorflow和matplotlib安装教程