关于tensorflow 中 placeholder 与 reshape的一点坑
程序员文章站
2022-03-11 16:59:30
...
在搭LeNet-5 模型时,在卷积层的输出到全连接层时,使用了reshape将四维的矩阵转化维2维矩阵时,发生了错误:
起初以为时类型转换发生了错误,然后演算过后发现并没有错误。然后改了下 训练数据的输入格式
# 定义输入输出placeholder, **修改前**
x = tf.placeholder(tf.float32,
[None,
mnist_inference.IMAGE_SIZE,
mnist_inference.IMAGE_SIZE,
mnist_inference.NUM_CHANNELS],
name='x-input')
# 定义输入输出placeholder。**修改后**
x = tf.placeholder(tf.float32,
[BATCH_SIZE,
mnist_inference.IMAGE_SIZE,
mnist_inference.IMAGE_SIZE,
mnist_inference.NUM_CHANNELS],
name='x-input')
然后错误没有了,写了一个简单的验证程序,来验针下placeholder这里出现的问题。
import tensorflow as tf
import numpy as np
x = tf.placeholder(tf.float32, [None,2,2,2],name='x-input')
x_shape = x.get_shape().as_list()
len = x_shape[1] * x_shape[2] * x_shape[3]
x_reshaped = tf.reshape(x, [x_shape[0],len])
y = x_reshaped + 1
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
data = np.arange(2 * 2 * 2 * 2).reshape([2, 2, 2, 2]).astype('float32')
out = sess.run(y, feed_dict={x:data})
print(out)
然后发生了如下错误,通过黄色标注的字体可以发现时发生在了reshape是发生了错误,reshape()函数无法识别None这里发生的转换,所以报错。
所以问题出现在reshape函数这块,reshap() 函数无法识别转化列表中的None是多少,这时可以使用python中的自动推导,也就是
# x_reshaped = tf.reshape(x, [x[0], len])
x_reshaped = tf.reshape(x, [-1,len])
这样完美解决问题,placeholder的shape参数可以为
x = tf.placeholder(tf.float32,
[None,
mnist_inference.IMAGE_SIZE,
mnist_inference.IMAGE_SIZE,
mnist_inference.NUM_CHANNELS],
name='x-input')
省时省力
上一篇: wamp下修改mysql访问密码的解决方法_PHP教程
下一篇: placeholder中实现换行