关于tf.image.resize_images的一个小问题
首选我们需要加载一张原图,然后使用对应格式解码,从而得到图像对应的三维矩阵
image_raw_data = tf.gfile.FastGFile('pic.jpg', 'r').read()
img_data = tf.image.decode_jpeg(image_raw_data)
由于我这里的图片是jpg格式的,因此使用了decode_jpeg,TensorFlow还提供了tf.image.decode_png来处理png格式的图片。
处理完之后的img_data数据格式是uint8类型的。可以直接使用tf.image.encode_jpeg(img_data)在编码回去,
然后保存成新的图片,新图和原图肯定是一直的。
但是涉及到使用使用图片大小调整的时候就可能遇到一个小问题。
图片大小调整函数为 tf.image.resize_images 。
首先这是关于该函数的一个介绍。
resized = tf.image.resize_images(img_data, (300, 300), method=0)
该函数第三个参数是选择一种算法,使得新图像尽量保存原始图上的所有信息。
有四种算法,这里不是讲算法的,就不说算法的内容了。
用的时候,会发现,由于原本是uint8类型的数据,只有method=1的时候,选择最近邻居法,
也就是 `ResizeMethod.NEAREST_NEIGHBOR` 的时候,可接受uint8类型,
其它算法,如果显示图像,都会发现不是正常内容。
显示图像的代码如下:
import matplotlib.pyplot as plt
plt.imshow(resized.eval())
plt.show()
不正常的原因就是,另外三个算法接受的数据类型应该是 tf.float32类型,而不是tf.uint8类型。
这时就需要,将数据类型进行一下转化。
TensorFlow提供了 tf.image.convert_image_dtype 函数来转换图像对应矩阵数据类型。
因此,想要使用其它几个函数显示图像的话,就需要先转换一下数据类型。
img_data = tf.image.convert_image_dtype(img_data, dtype=tf.float32)
这样其它三个函数就能正常显示。当然如果这是修改method=1,你会发现代码会报错。数据类型不对。
所以根据你需要的算法,修改数据类型即可。
然后又会碰到新的问题。当你想要存储这张图片的时候,又会发现数据编码有问题。
encoded_image = tf.image.encode_jpeg(resized)
那是什么原因呢?仔细看上面解码处的逻辑,
使用 tf.image.decode_jpeg 解码出来的数据类型是uint8的,但是这会儿编码数据类型又是float32的,
数据类型对不上,大概率会出错。
那怎么办呢?再使用刚介绍的数据类型转换函数处理一下,把数据类型转换为uint8即可。
至此就没有问题了。
把完整的逻辑代码贴出来:
import matplotlib.pyplot as plt
import tensorflow as tf
# 读取图像的原始数据
image_raw_data = tf.gfile.FastGFile('./cat.jpg', 'r').read()
with tf.Session() as sess:
# 将图像使用jpeg的格式解码从而得到图像对应的三维矩阵
# TensorFlow还提供了tf.image.decode_png函数对png格式的图像进行解码
# 解码之后的结果为一个张量 在使用它的取值之前需要明确调用运行的过程
img_data = tf.image.decode_jpeg(image_raw_data)
# 使用pyplot工具可视化得到的图像
plt.imshow(img_data.eval())
plt.show()
# 将数据的类型转化成实数方便下面的样例程序对图片进行处理
img_data = tf.image.convert_image_dtype(img_data, dtype=tf.float32)
# 通过tf.image.resize_images函数调整图像的大小。这个函数第一个参数为原始图像
# 第二个和第三个参数为调整后的图像的大小,method参数给出了调整图像大小的算法
resized = tf.image.resize_images(img_data, (300, 300), method=0)
# 输出调整后图像的大小 此处结果为(300, 300, ?) 表示图像的大小为 300*300
# 但图像的深度在没有明确设置之前会是?
print(resized.get_shape())
# 双线性插值法 数据类型需要是float32
plt.imshow(resized.eval())
plt.show()
# 编码时需要是uint8类型 所以需要先转化一下数据类型
resized = tf.image.convert_image_dtype(resized, dtype=tf.uint8)
encoded_image = tf.image.encode_jpeg(resized)
# 保存新图片
with tf.gfile.GFile('./cp.jpg', 'wb') as f:
f.write(encoded_image.eval())