欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

TensorFlow中Tensor的shape概念与tf ops:tf.reshape

程序员文章站 2022-05-31 19:01:37
...

田海立@CSDN 2020-10-17

图解NCHW与NHWC数据格式》中从逻辑表达和物理存储角度用图的方式讲述了NHWC与NCHW两种数据格式,数据shape是可以改变的,本文介绍TensorFlow里Tensor的Shape概念,并用图示和程序阐述了reshape运算。

 

一、TensorFlow中Tensor的Shape

TensorFlow中的数据都是由Tensor来表示,Shape相关有下列一些概念:

  • Rank:维数
  • Dimension:表达每一维长度
  • Size:所有的Dimension数值相乘,也就是Tensor里数据元素的尺寸了

rank为0/1/2的典型Tensor如下图所示:

TensorFlow中Tensor的shape概念与tf ops:tf.reshape

 

Tensor rank为3时,数据表达为:

TensorFlow中Tensor的shape概念与tf ops:tf.reshape

 

二、Tensor的逻辑表达与物理存储

如《图解NCHW与NHWC数据格式》中所述,数据可以从逻辑上和物理排布上去理解。而本文第一节中你可以仍从逻辑上去理解,还未牵涉到物理存储数据排布。

三维以下的比较容易理解,各个ML框架之间也没大的区别,对于三维(及以上)Tensor的排布就很不同了,这里着重介绍3-D。

我们已经知道TensorFlow的Tensor缺省是NHWC的,对于上面的shape(3, 2, 5)的Tensor【n为1】,在TensorFlow中应该是这样的:

TensorFlow中Tensor的shape概念与tf ops:tf.reshape

如果数据值按顺序排布如下,

      [[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9]],

       [[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]]]

那么对应上面三维立方体的摆放应该如下:

TensorFlow中Tensor的shape概念与tf ops:tf.reshape

 

三、tf.reshape()运算

 

reshape原型如下:

tf.reshape(
    tensor, shape, name=None
)

 

3.1 tf.reshape()的不改变性

  1. tf.reshape()运算不改变数据的物理摆布,也就是说一个Tensor reshape到别的shape只是逻辑上shape的改变,存储的数据不会改变;
  2. tf.reshape()也就不会改变Tensor的size。指定的新的shape的size如果与原Tensor的size不一致就会报错。比如上面Shape(3, 2, 5)的Tensor就没法reshape成7x?。

有了上面两个原则,tf.reshape()运算就很容易理解了,物理存储不变,就看rank以及各个dimension怎么取了。

 

3.2 tf.reshape() 图示

比如,上面Tensor有30个数:从0~29顺序存储。可以存储为(3, 2, 5)【上面介绍过的3-D】,也可以存储为2D的(3, 10)或(6, 5),等。

TensorFlow中Tensor的shape概念与tf ops:tf.reshapeTensorFlow中Tensor的shape概念与tf ops:tf.reshape

 

3.3 程序实现如下:

TF2.0以后的版本上,直接可以执行,而不用还要在session下执行。当然前提是已经

import tensorflow as tf

1. 30个数的数据

>>> t = tf.range(30)
>>> t
<tf.Tensor: shape=(30,), dtype=int32, numpy=
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=int32)>
>>> 

2. shape(3, 2, 5)

>>> t = tf.reshape(t, [3,2,5])
>>> t
<tf.Tensor: shape=(3, 2, 5), dtype=int32, numpy=
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9]],

       [[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]]], dtype=int32)>
>>> 

3. shape(3, 10)

>>> t = tf.reshape(t, [3, 10])
>>> t
<tf.Tensor: shape=(3, 10), dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]], dtype=int32)>
>>> 

4. shape(6, 5)

>>> t = tf.reshape(t, [6, 5])
>>> t
<tf.Tensor: shape=(6, 5), dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24],
       [25, 26, 27, 28, 29]], dtype=int32)>
>>> 

 

四、小结

本文介绍了TensorFlow里Tensor的Shape概念,并用图示和实际程序解释了reshape的变化。