TensorFlow中Tensor的shape概念与tf ops:tf.reshape
田海立@CSDN 2020-10-17
《图解NCHW与NHWC数据格式》中从逻辑表达和物理存储角度用图的方式讲述了NHWC与NCHW两种数据格式,数据shape是可以改变的,本文介绍TensorFlow里Tensor的Shape概念,并用图示和程序阐述了reshape运算。
一、TensorFlow中Tensor的Shape
TensorFlow中的数据都是由Tensor来表示,Shape相关有下列一些概念:
- Rank:维数
- Dimension:表达每一维长度
- Size:所有的Dimension数值相乘,也就是Tensor里数据元素的尺寸了
rank为0/1/2的典型Tensor如下图所示:
Tensor rank为3时,数据表达为:
二、Tensor的逻辑表达与物理存储
如《图解NCHW与NHWC数据格式》中所述,数据可以从逻辑上和物理排布上去理解。而本文第一节中你可以仍从逻辑上去理解,还未牵涉到物理存储数据排布。
三维以下的比较容易理解,各个ML框架之间也没大的区别,对于三维(及以上)Tensor的排布就很不同了,这里着重介绍3-D。
我们已经知道TensorFlow的Tensor缺省是NHWC的,对于上面的shape(3, 2, 5)的Tensor【n为1】,在TensorFlow中应该是这样的:
如果数据值按顺序排布如下,
[[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9]],
[[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]],
[[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29]]]
那么对应上面三维立方体的摆放应该如下:
三、tf.reshape()运算
reshape原型如下:
tf.reshape(
tensor, shape, name=None
)
3.1 tf.reshape()的不改变性
- tf.reshape()运算不改变数据的物理摆布,也就是说一个Tensor reshape到别的shape只是逻辑上shape的改变,存储的数据不会改变;
- tf.reshape()也就不会改变Tensor的size。指定的新的shape的size如果与原Tensor的size不一致就会报错。比如上面Shape(3, 2, 5)的Tensor就没法reshape成7x?。
有了上面两个原则,tf.reshape()运算就很容易理解了,物理存储不变,就看rank以及各个dimension怎么取了。
3.2 tf.reshape() 图示
比如,上面Tensor有30个数:从0~29顺序存储。可以存储为(3, 2, 5)【上面介绍过的3-D】,也可以存储为2D的(3, 10)或(6, 5),等。
3.3 程序实现如下:
TF2.0以后的版本上,直接可以执行,而不用还要在session下执行。当然前提是已经
import tensorflow as tf
1. 30个数的数据
>>> t = tf.range(30)
>>> t
<tf.Tensor: shape=(30,), dtype=int32, numpy=
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=int32)>
>>>
2. shape(3, 2, 5)
>>> t = tf.reshape(t, [3,2,5])
>>> t
<tf.Tensor: shape=(3, 2, 5), dtype=int32, numpy=
array([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9]],
[[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]],
[[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29]]], dtype=int32)>
>>>
3. shape(3, 10)
>>> t = tf.reshape(t, [3, 10])
>>> t
<tf.Tensor: shape=(3, 10), dtype=int32, numpy=
array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]], dtype=int32)>
>>>
4. shape(6, 5)
>>> t = tf.reshape(t, [6, 5])
>>> t
<tf.Tensor: shape=(6, 5), dtype=int32, numpy=
array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29]], dtype=int32)>
>>>
四、小结
本文介绍了TensorFlow里Tensor的Shape概念,并用图示和实际程序解释了reshape的变化。
上一篇: 帆软报表平台的使用方法 数据分析