欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

关于tensorflow 中 placeholder 与 reshape的一点坑

程序员文章站 2022-03-11 16:59:30
...

在搭LeNet-5 模型时,在卷积层的输出到全连接层时,使用了reshape将四维的矩阵转化维2维矩阵时,发生了错误:
关于tensorflow 中 placeholder 与 reshape的一点坑
起初以为时类型转换发生了错误,然后演算过后发现并没有错误。然后改了下 训练数据的输入格式

    # 定义输入输出placeholder, **修改前**
    x = tf.placeholder(tf.float32,
                       [None,
                        mnist_inference.IMAGE_SIZE,
                        mnist_inference.IMAGE_SIZE,
                        mnist_inference.NUM_CHANNELS],
                       name='x-input')
    # 定义输入输出placeholder。**修改后**
    x = tf.placeholder(tf.float32,
                       [BATCH_SIZE,
                        mnist_inference.IMAGE_SIZE,
                        mnist_inference.IMAGE_SIZE,
                        mnist_inference.NUM_CHANNELS],
                       name='x-input')

然后错误没有了,写了一个简单的验证程序,来验针下placeholder这里出现的问题。

import tensorflow as tf
import numpy as np

x = tf.placeholder(tf.float32, [None,2,2,2],name='x-input')
x_shape = x.get_shape().as_list()
len = x_shape[1] * x_shape[2] * x_shape[3]
x_reshaped = tf.reshape(x, [x_shape[0],len])

y = x_reshaped + 1
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    data = np.arange(2 * 2 * 2 * 2).reshape([2, 2, 2, 2]).astype('float32')
    out = sess.run(y, feed_dict={x:data})
    print(out)

然后发生了如下错误,通过黄色标注的字体可以发现时发生在了reshape是发生了错误,reshape()函数无法识别None这里发生的转换,所以报错。
关于tensorflow 中 placeholder 与 reshape的一点坑

所以问题出现在reshape函数这块reshap() 函数无法识别转化列表中的None是多少,这时可以使用python中的自动推导,也就是

# x_reshaped = tf.reshape(x, [x[0], len])
x_reshaped = tf.reshape(x, [-1,len])

这样完美解决问题,placeholder的shape参数可以为

x = tf.placeholder(tf.float32,
                       [None,
                        mnist_inference.IMAGE_SIZE,
                        mnist_inference.IMAGE_SIZE,
                        mnist_inference.NUM_CHANNELS],
                       name='x-input')

省时省力

相关标签: placeholder reshape