欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

关于TensorFlow中tf.reshape()及shape的问题

程序员文章站 2022-05-31 18:48:08
...

我看别人都有转载请声明出处,我也写上,:)

转载请申明出处,https://blog.csdn.net/sinat_28704977/article/details/80626689

更重要的,如有错误,请批评指正,不胜感谢。


先上代码:

import numpy as np
import tensorflow as tf
a = tf.constant([
    [ [ 1.0, 2.0, 3.0, 4.0 ],
      [ 5.0, 6.0, 7.0, 8.0 ],
      [ 8.0, 7.0, 6.0, 5.0 ],
      [ 4.0, 3.0, 2.0, 1.0 ] ],
    [ [ 4.0, 3.0, 2.0, 1.0 ],
      [ 8.0, 7.0, 6.0, 5.0 ],
      [ 1.0, 2.0, 3.0, 4.0 ],
      [ 5.0, 6.0, 7.0, 8.0 ] ]
])
c=a
image_shape = c.get_shape()
b=tf.reshape(a,[4,4,2])
a = tf.reshape(a, [ 1, 4, 4, 2 ])
#这里reshape是强制转换,是直接从上面矩阵挨个取值并转换为目标shape,转换后的图片不是如上图一样的两个4*4数组即(2,4,4),而是(1,4,4,2)。

with tf.Session() as sess:
    g,h=(c,image_shape[-1].value)
    d = image_shape[ 1: ].as_list()#[4,4]一维列表
    dim=1
    for test in d:#image_shape为(2,4,4)tensor_shape类型,维数为3
        print(test)
        dim*=int(test)
    e=tf.reshape(a,[-1,dim])
    f=tf.reshape(e,(32,1))
    print(e,'\n',e.get_shape())
    image = sess.run(e)
    image2 = sess.run(f)
    print(image,image2)

问题1

代码中,a的shape为(2,4,4),并不是我们直观认为的图像格式4*4*2。这是需要注意的一点。

问题2

首先,tensor有get_shape()方法获得tensor的shape,类型为tensor_shape,然后,shape[1:]的意思是取从第二个维度开始的shape,例如返回的tensor_shape为(1,4,4,2),加上as_list()就成为[4,4,2]一个一维列表.

问题3

e=tf.reshape(a,[-1,dim])

这里不是转为列向量,因为如果是一个列向量,那么shape应该是(1,32,1),如果是列向量,则是(32,1)。这里转为:最后一个维度是32的一个向量,也就是(1,32)。要区分清楚。
结果图:
关于TensorFlow中tf.reshape()及shape的问题