tensorflow 读取并显示图片
# -*- coding: utf-8 -*-
import tensorflow as tf
filename = '2.jpg'
with tf.gfile.FastGFile(filename,'rb') as f:
image_data = f.read()
with tf.Session() as sess:
image = sess.run(image_data)
print(image.shape)
这里我们用到tf.gfile.FastGFile的read方法
read
read(n=-1)
Returns the contents of a file as a string.Args:
n: Read ‘n’ bytes if n != -1. If n = -1, reads to end of file.
Returns:
‘n’ bytes of the file (or whole file) in bytes mode or ‘n’ bytes of the string if in string (regular) mode.
read方法默认以字符串返回整个文件的内容,我们这里的参数是rb
,所以按照字节返回
# -*- coding: utf-8 -*-
import tensorflow as tf
filename = '2.jpg'
with tf.gfile.FastGFile(filename,'r') as f:
image_data = f.read()
with tf.Session() as sess:
image = sess.run(image_data)
这里我尝试让其直接返回字符串,但是会报错
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte
我也不是很清楚,不多用bytes返回应该就足够了。另一个读取方法是
image_data = tf.read_file(filename)
我觉得这两个读取差别不大,tf.read_file
返回的是一个Tensor
,所以需要sess,而tf.gfile.FastGFile
是直接返回的字节或字符串。这两个方法都是读取图片的原始数据,所以我们需要对其进行解码。
tf.image.decode_jpeg(image_data)
tf.image.decode_image(image_data)
tf.decode_raw(image_data,tf.uint8)
第一个函数就比较简单了,我们读取jpeg格式的图片,然后用tf.image.decode_jpeg(image_data)
进行解码得到图片。
# -*- coding: utf-8 -*-
import tensorflow as tf
import matplotlib.pyplot as plt
filename = '2.jpg'
with tf.gfile.FastGFile(filename,'rb') as f:
image_data = f.read()
with tf.Session() as sess:
image_data = tf.image.decode_jpeg(image_data)
image = sess.run(image_data)
plt.imshow(image)
print(image.shape)
如果是一个彩色的图片是可以正常输出来的,但是一个灰度图片就不行了。
TypeError: Invalid dimensions for image data
我不死心看了一下数据形状
# -*- coding: utf-8 -*-
import tensorflow as tf
import matplotlib.pyplot as plt
filename = '2.jpg'
with tf.gfile.FastGFile(filename,'rb') as f:
image_data = f.read()
with tf.Session() as sess:
image_data = tf.image.decode_jpeg(image_data)
image = sess.run(image_data)
print(image.shape)
输出
(400, 600, 1)
但是为什会报错呢?这是因为plt不能绘制这样的数据格式,灰度图必须是(height,width)
,但是数据读取会得到(height,width,channels)
普通的彩色图就是(height,width,3)
,灰度图就是(height,width,1)
,为了让plt绘制出来我们需要reshap
一下。
# -*- coding: utf-8 -*-
import tensorflow as tf
import matplotlib.pyplot as plt
filename = '2.jpg'
with tf.gfile.FastGFile(filename,'rb') as f:
image_data = f.read()
with tf.Session() as sess:
image_data = tf.image.decode_jpeg(image_data)
image = sess.run(image_data)
h,w,c=image.shape
assert c==1
image = image.reshape(h,w)
plt.imshow(image)
print(image.shape)
这样就可以绘制图片了,但是发现是彩色的,这是因为plt默认有颜色的,可以
plt.imshow(image,cmap='gray')
有时候是彩色图片但是我们想要变成灰色
image_data = tf.image.decode_jpeg(image_data,channels=1)
这样原本的彩色图片就会变成灰色
- 0: Use the number of channels in the JPEG-encoded image.
- 1: output a grayscale image.
- 3: output an RGB image.
但是要注意,数据shape是
(height,width,3)->(height,width,1)
若要显示还是需要reshape
# -*- coding: utf-8 -*-
import tensorflow as tf
import matplotlib.pyplot as plt
filename = '1.jpg' # 彩色
with tf.gfile.FastGFile(filename,'rb') as f:
image_data = f.read()
with tf.Session() as sess:
image_data = tf.image.decode_jpeg(image_data,channels=1)
image = sess.run(image_data)
h,w,c=image.shape
assert c==1
image = image.reshape(h,w)
plt.imshow(image,cmap='gray')
print(image.shape)
tf.image.decode_image(image_data)
这个函数和tf.image.decode_jpeg
类似,主要看一下返回
Tensor with type uint8 with shape [height, width, num_channels] for BMP, JPEG, and PNG images and shape [num_frames, height, width, 3] for GIF images.
jpeg,png,bmp,甚至是gif都可以解码。我觉得tf.image.decode_image
要比tf.image.decode_jpeg
更方便一点。但是你会发现ValueError: 'images' contains no shape.
,这是因为这个函数需要返回shape
# -*- coding: utf-8 -*-
import tensorflow as tf
import matplotlib.pyplot as plt
filename = '1.jpg' # 彩色
with tf.gfile.FastGFile(filename,'rb') as f:
image_data = f.read()
with tf.Session() as sess:
image_data = tf.image.decode_image(image_data)
image_data.set_shape([None,None,None])
image = sess.run(image_data)
h,w,c=image.shape
assert c==1
image = image.reshape(h,w)
plt.imshow(image,cmap='gray')
print(image.shape)
最后一个是
tf.decode_raw(image_data,tf.uint8)
这个函数就有点傻乎乎的了,
# -*- coding: utf-8 -*-
import tensorflow as tf
filename = '1.jpg'
with tf.gfile.FastGFile(filename,'rb') as f:
image_data = f.read()
with tf.Session() as sess:
image_data = tf.decode_raw(image_data,tf.uint8)
image = sess.run(image_data)
print(image)
print(image.shape)
输出
[255 216 255 ... 127 255 217]
(47982,)
就是原来什么样子我就给你输出什么样子,不过需要指定解码的数据类型。我们是图片所以用tf.uint8
。我的困惑就是从这个函数开始的,
# -*- coding: utf-8 -*-
import tensorflow as tf
filename = '1.jpg'
with tf.gfile.FastGFile(filename,'rb') as f:
image1 = f.read()
image2 = tf.read_file(filename)
with tf.Session() as sess:
print(len(image1))
image2 = sess.run(image2)
print(len(image2))
输出
47982
47982
但是图片的shape是什么(400,600,3)
,就是720000,这可怎么都不一样啊,我是不知道如何根据这些数据解码成图片的,jpeg的编码协议我也不清楚,所以我猜测tf.decode_raw
与tf.gfile.FastGFile
,tf.read_file
并不是配对使用的。
tensorflow这个垃圾,我也是垃圾
# -*- coding: utf-8 -*-
import tensorflow as tf
from PIL import Image
import numpy as np
filename = '1.jpg'
img = Image.open(filename)
img = np.array(img)
h,w,c=img.shape
img_raw = img.tobytes()
example = tf.train.Example(features=tf.train.Features(feature={
'h': tf.train.Feature(int64_list=tf.train.Int64List(value=[h])),
'w': tf.train.Feature(int64_list=tf.train.Int64List(value=[w])),
'c': tf.train.Feature(int64_list=tf.train.Int64List(value=[c])),
'img': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer= tf.python_io.TFRecordWriter("pic.tfrecords")
writer.write(example.SerializeToString())
writer.close()
filename_queue = tf.train.string_input_producer(["pic.tfrecords"])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
features = tf.parse_single_example(serialized_example,features={
'h': tf.FixedLenFeature([], tf.int64),
'w': tf.FixedLenFeature([], tf.int64),
'c': tf.FixedLenFeature([], tf.int64),
'img' : tf.FixedLenFeature([], tf.string)})
img = tf.decode_raw(features['img'], tf.uint8)
h = tf.cast(features['h'], tf.int32)
w = tf.cast(features['w'], tf.int32)
c = tf.cast(features['c'], tf.int32)
img = tf.reshape(img, [h,w,c])
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
print(sess.run(img))
coord.request_stop()
coord.join(threads)
上面这段简单的代码是制作与读取tfrecord文件,这里用到了tf.decode_raw
,可以看到我们是直接读取了图片的数据(每个像素的值),而不是原始图片的数据(jpeg编码后的数据)。
上一篇: 揉按手法