[Tensorflow] Reader & queue图片读取管道
程序员文章站
2022-05-03 13:14:08
...
背景:在训练数据很大的情况下,无法将数据全部读入内存。除了自己写个工具处理,还可以使用tensorflow提供的工具。
一、流程
- 文件名集合,用list表示。 如["a.jpg","b.jpg"];
- 文件名队列。调用API tf.train.string_input_producer;
- 构建一个适合该文件格式的reader,如tf.WholeFileReader.read(queue)返回一个二进制流格式的tensor
- 一个decoder用于解码reader从文件中读到的数据。 如tf.image.decode_jpeg(ts)从tensor从解码jpeg图片,返回tensor
- 步骤四中的tensor可以直接用于神经网络输入
二、例子
import tensorflow as tf;
filename_list = ['1.jpg']
filename_queue = tf.train.string_input_producer(filename_list,num_epochs=5,shuffle=True) #创建输入队列
image_reader = tf.WholeFileReader()
key, val = image_reader.read(filename_queue) #key和val都是tensor,key是文件名字符串,val是图片二进制数据,需要decode。
image = tf.cast(tf.image.decode_jpeg(val),dtype=tf.float32)/256-0.5
with tf.Session() as sess:
sess.run(tf.local_variables_initializer()) #如果定义了num_epochs,一定要加这行初始化。
coord = tf.train.Coordinator() #协同启动的线程
threads = tf.train.start_queue_runners(sess=sess, coord=coord) #启动线程运行队列
filename,arr=sess.run([key,image])
print(filename)
print(type(filename)) # bytes
print(arr.shape)
print(type(arr)) # numpy.ndarray
关于把数据组合成batch,可以这么实现的:
k1,v1=image_reader.read(filename_queue)
k2,v2=image_reader.read(filename_queue)
labels=tf.stack([v1,v2])
也可以这么做:import os
import sys
import tensorflow as tf
class IRIS_DS:
#tensor
train_input=None
train_lable=None
def getFileNameFrom(self,dir,accept_suffix="jpg|bmp"): #文件名list
suffix=set(accept_suffix.split("|"))
filename_list=[]
for root,dirs,files in os.walk(dir):
for file in files:
if file.strip().split(".")[-1] in suffix:
path="%s/%s" %(root,file)
filename_list.append(path)
return filename_list
def label_decoder(self,label):
return tf.substr(label,28,3)
def input_decoder(self,input):
return tf.cast(tf.image.decode_bmp(input),dtype=tf.float32) #此处也可以保存为二进制数据,再使用Opencv imdecode进行后续处理
def __init__(self,sess,train_dir="D:/Dataset/segmented_iris",test_dir=""):
#(1)文件名list
train_filename_list=self.getFileNameFrom(train_dir)
print("[*]从%s读取训练数据,一共有%d张图" %(train_dir,len(train_filename_list)))
#(2)文件名队列
train_filename_queue=tf.train.string_input_producer(train_filename_list,shuffle=True)
#(3)reader
train_reader=tf.WholeFileReader()
#(4)生成
key,val=train_reader.read(train_filename_queue)
self.train_input = self.input_decoder(val)
self.train_label = self.label_decoder(key)
#(5)初始化
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
def int2onehot(self,val,classes=1000):
one_hot=[]
for _ in range(classes):
one_hot.append(0)
one_hot[val]=1
return one_hot
def getTrainBatch(self,sess,batch_sz):
data=([],[])
for _ in range(batch_sz):
input,label=sess.run([self.train_input,self.train_label])
data[0].append(input)
label=self.int2onehot(int(label.decode("utf-8")))
data[1].append(label)
return data
if __name__=="__main__":
sess=tf.Session()
ds=IRIS_DS(sess)
input,label=ds.getTrainBatch(sess,11)
print(input[0].shape)
print(label[0])