欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

[Tensorflow] Reader & queue图片读取管道

程序员文章站 2022-05-03 13:14:08
...

背景:在训练数据很大的情况下,无法将数据全部读入内存。除了自己写个工具处理,还可以使用tensorflow提供的工具。

[Tensorflow] Reader & queue图片读取管道

一、流程

  1. 文件名集合,用list表示。  如["a.jpg","b.jpg"];
  2. 文件名队列。调用API  tf.train.string_input_producer;
  3. 构建一个适合该文件格式的reader,如tf.WholeFileReader.read(queue)返回一个二进制流格式的tensor
  4. 一个decoder用于解码reader从文件中读到的数据。 如tf.image.decode_jpeg(ts)从tensor从解码jpeg图片,返回tensor
  5. 步骤四中的tensor可以直接用于神经网络输入

二、例子

import tensorflow as tf;    
  
filename_list = ['1.jpg']  
filename_queue = tf.train.string_input_producer(filename_list,num_epochs=5,shuffle=True) #创建输入队列  
image_reader = tf.WholeFileReader()  
key, val = image_reader.read(filename_queue)   #key和val都是tensor,key是文件名字符串,val是图片二进制数据,需要decode。
image = tf.cast(tf.image.decode_jpeg(val),dtype=tf.float32)/256-0.5
with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())           #如果定义了num_epochs,一定要加这行初始化。
  coord = tf.train.Coordinator() #协同启动的线程    
  threads = tf.train.start_queue_runners(sess=sess, coord=coord) #启动线程运行队列    
  filename,arr=sess.run([key,image])        
  print(filename)            
  print(type(filename))   # bytes  
  print(arr.shape)  
  print(type(arr))        # numpy.ndarray
关于把数据组合成batch,可以这么实现的:
k1,v1=image_reader.read(filename_queue) 
k2,v2=image_reader.read(filename_queue)
labels=tf.stack([v1,v2])
也可以这么做:
import os
import sys
import tensorflow as tf
class IRIS_DS: 
  #tensor
  train_input=None
  train_lable=None 
  def getFileNameFrom(self,dir,accept_suffix="jpg|bmp"): #文件名list
    suffix=set(accept_suffix.split("|"))
    filename_list=[]
    for root,dirs,files in os.walk(dir):
      for file in files:
        if file.strip().split(".")[-1] in suffix:
          path="%s/%s" %(root,file)
          filename_list.append(path) 
    return filename_list
    
  def label_decoder(self,label):        
    return tf.substr(label,28,3)
    
  def input_decoder(self,input):
    return tf.cast(tf.image.decode_bmp(input),dtype=tf.float32)   #此处也可以保存为二进制数据,再使用Opencv imdecode进行后续处理
    
  def __init__(self,sess,train_dir="D:/Dataset/segmented_iris",test_dir=""):
    #(1)文件名list
    train_filename_list=self.getFileNameFrom(train_dir)
    print("[*]从%s读取训练数据,一共有%d张图" %(train_dir,len(train_filename_list)))
    #(2)文件名队列
    train_filename_queue=tf.train.string_input_producer(train_filename_list,shuffle=True)
    #(3)reader
    train_reader=tf.WholeFileReader()
    #(4)生成
    key,val=train_reader.read(train_filename_queue)
    self.train_input = self.input_decoder(val)
    self.train_label = self.label_decoder(key)
    #(5)初始化 
    sess.run(tf.local_variables_initializer()) 
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord) 
  def int2onehot(self,val,classes=1000):
    one_hot=[]
    for _ in range(classes):
      one_hot.append(0)
    one_hot[val]=1
    return one_hot
  def getTrainBatch(self,sess,batch_sz):
    data=([],[])
    for _ in range(batch_sz):
      input,label=sess.run([self.train_input,self.train_label])
      data[0].append(input)
      label=self.int2onehot(int(label.decode("utf-8")))
      data[1].append(label)
    return data
    
if __name__=="__main__":
  sess=tf.Session()
  ds=IRIS_DS(sess)
  input,label=ds.getTrainBatch(sess,11) 
  print(input[0].shape)
  print(label[0])


https://zhuanlan.zhihu.com/p/27238630