欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

TensorFlow:四种类型数据的读取流程及API讲解和代码实现

程序员文章站 2022-07-02 15:30:14
...

由于数据量一般会比较大,所以更多的会使用直接从文件中读取。

但是对于不同的文件类型,需要不同的文件处理API,有时候比较容易弄混淆,接下来就来梳理一下。

一.文件读取流程

TensorFlow:四种类型数据的读取流程及API讲解和代码实现

如上图所示,展示了文件读取的大致流程。 
最左边的A、B、C是存储于磁盘中文件,经过打乱文件之后(这里是默认的乱序读取,只是文件的顺序乱,但是文件内容不受影响),进入到文件队列中(Filename Queue)。文件队列当中的文件经过阅读器(Reader)处理,存储到内存当中。接下来对文件进行解码(Decode),解码之后进入样本队列当中进行批处理,此时经过批处理之后就可以用于模型训练了。

现在举例,对于读取CSV文件,大致要经历一下几步: 
1. 找到文件,并构造文件的列表(一阶张量) 
2. 构造文件队列 
3. 读取文件内容 
4. 解码CSV并读取内容 
5. 开启会话运行,得出训练结果

二.文件读取的API

1.文件队列构造

tf.train.string_input_producer(string_tensor,num_epochs,shuffle=True)

  • 将输出字符串(例如文件名)输入到管道队列
  • string_tensor:含有文件名的一阶张量,需要指定文件路径
  • num_epochs:将全部数据循环的次数
  • return:具有输出字符串的队列

2.文件阅读器

此时需要根据文件的格式,选择对应的文件阅读器

(1) 文本文件:tf.TextLineReader()

  • 读取文本文件,逗号分隔值(CSV)格式,默认按行读取
  • return:读取器实例

(2)二进制文件:tf.FixedLengthRecordReader(record_bytes)

  • 读取每个记录是固定数量字节的二进制文件
  • record_bytes:整型,指定每次读取的字节数
  • return:读取器实例

(3)图片文件:tf.WholeReader()

  • 将文件的全部内容作为值输出,即一次读取一整个文件
  • return:读取器实例

(4)TFRecords文件:tf.TFRecordReader()

  • 读取 TFRecords文件
  • return:读取器实例
注:这几种文件格式都有一个共同的读取方法:read(file_queue)
  • 从队列中指定内容数量
  • file_name : 文件队列
  • ruturn : 返回一个Tensor元组(key,value) 
    • key : 文件名
    • value : 每次读取的值(一行文本、一张图片或指定字节的值)

3.文件内容解码器

由于从文件中读取的是字符串,需要函数去解析这些字符串,最后变换成张量 
(1)CSV文件: 
tf.decode_csv(records,record_defaults=None,field_delim=None,name=None)

  • 将CSV文件转换成张量,需要tf.TextLineReader()搭配使用
  • records : tensor型字符串,每个字符串是CSV中的记录行(即value值)
  • record_defaults : 此参数决定了所得张量的类型,并设置一个值,如果在输入字符串中缺少则使用默认值,如[[1],[1]] 或者[[“None”],[“None”]]
  • field_dim : 默认分隔符“ ,”

(2)二进制文件: 
tf.decode_raw(bytes,out_type,little_endian=None,name=None)

  • 将字节转换为一个数字向量表示,字节为以字符串类型的张量
  • 与函数tf.FixedLengthRecordReader搭配使用
  • 将二进制转换为uint8格式

(3)图像文件:

  • 1)tf.image.decode_jpeg(contens)

    • 将JPEG编码的图像解码为uint8张量
    • return : uint8张量,3-D形状[height,width,channels]
  • 2) tf.image.decode_png(contents)

    • 将PNG编码的图像解码为uint8或者uint16编码
    • return : 张量类型,3-D形状[height,width,channels]

(4)TFRecords文件: 
TFRecords文件是TensorFlow中的统一格式,它的存储和读取方式较为复杂,我会在下篇文章中单独来梳理这部分的内容。

4.批处理数据

对数据进行批处理需要在会话开启之前进行 
(1)tf.train.batch(tensors,batch_size,num_threads=1,capacity=32,name=None)

  • 读取指定大小(个数)的张量
  • tensor : 包含张量的列表
  • batch_size : 从队列中读取的批处理数据大小
  • num_threads : 进入队列的线程数
  • capacity : 整数,批处理队列中元素的最大数量
  • teturn : tensors

(2)tf.train.shuffle_batch(tensors,batch_size,capacity,min_after_dequeue,num_threads=1,capacity=32,name=None)

  • 乱序读取指定大小(数量)的张量
  • min_after_dequeue : 留下队列里的张量个数,能够保持随机打乱

三.示例代码

1.CSV文件读取案例

def csvread(filelist):
    """
    CSV文件读取
    :param filelist: 文件的列表(1阶张量)
    :return:None
    """
    #2.构造文件的队列
    file_queue = tf.train.string_input_producer(filelist)

    #3.读取文件内容tf.decode_csv()
    #构造阅读器
    reader = tf.TextLineReader()
    #读队列文件内容,一行
    key,value = reader.read(file_queue)

    #4、解码csv文件
    #指定每一行格式的默认值,类型,[[1],[2.0],[1]]
    records = [["None"],["None"]]

    example,label = tf.decode_csv(value,record_defaults=records)

    #批处理读取数据
    example_batch,label_batch = tf.train.batch([example,label],batch_size=20,num_threads=1,capacity=100)

    #5、会话运行结果
    with tf.Session() as sess:
        #开启线程协调器
        coord = tf.train.Coordinator()

        #创建子线程去进行操作,返回线程列表
        threads = tf.train.start_queue_runners(sess,coord = coord)

        #打印
        print(sess.run([example_batch,label_batch]))

        #回收
        coord.request_stop()   #强制请求线程停止
        coord.join(threads)    #等待线程终止回收

    return None

if __name__ == '__main__':
    #列出文件目录,构造路径+文件名的列表,"A.csv"...
    # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表
    filename = os.listdir('./data/csvdata')

    #加上路径
    file_list = [os.path.join('./data/csvdata', file) for file in filename]

    csvread(file_list)

2,图像文件读取案例

./data/dog文件中存储了100张 *.jpg格式的狗的图片

def picread(file_list):
    """
    读取狗图片并转换成张量
    :param file_list:
    :return:
    """
    #1、构造文件的队列
    file_queue = tf.train.string_input_producer(file_list)

    #2、生成图片读取器,读取队列内容
    reader = tf.WholeFileReader()   #返回读取器实例

    key ,value = reader.read(file_queue)

    print(key,value)

    #3.进行图片的解码
    image = tf.image.decode_jpeg(value)

    print(image)

    #4.处理图片的大小
    image_resize = tf.image.resize_images(image,[256,256])

    print(image_resize)

    #设置静态形状   ,动态形状也可以
    image_resize.set_shape([256,256,3])

    print(image_resize)

    #5.进行批处理                  #此处image_siez必须指定形状,而且要为列表
    image_batch = tf.train.batch([image_resize],batch_size=100,num_threads=1,capacity=100)

    print(image_batch)

    return image_batch


if __name__ == '__main__':

    # 找到文件路径,名字,构造路径+文件名的列表,"A.csv"...
    # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表
    filename = os.listdir('./data/dog')

    #加上路径
    file_list = [os.path.join('./data/dog', file) for file in filename]

    image_batch = picread(file_list)

    with tf.Session() as sess:
        #定义线程协调器
        coord = tf.train.Coordinator()

        #开启线程
        threads = tf.train.start_queue_runners(sess,coord=coord)

        print(sess.run(image_batch))

        #回收线程
        coord.request_stop()
        coord.join(threads)

3.二进制文件读取案例

此案例中数据是使用的下载好的二进制的cifar10数据

#读取二进制转换文件
class CifarRead(object):
    """
    读取二进制文件转换成张量,写进TFRecords,同时读取TFRcords
    """
    def __init__(self,file_list):
        """
        初始化图片参数
        :param file_list:图片的路径名称列表
        """

        #文件列表
        self.file_list = file_list

        #图片大小,二进制文件字节数
        self.height = 32
        self.width = 32
        self.channel = 3
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.channel
        self.bytes = self.label_bytes + self.image_bytes


    def read_and_decode(self):
        """
        解析二进制文件到张量
        :return: 批处理的image,label张量
        """
        #1.构造文件队列
        file_queue = tf.train.string_input_producer(self.file_list)

        #2.阅读器读取内容
        reader = tf.FixedLengthRecordReader(self.bytes)

        key ,value = reader.read(file_queue)    #key为文件名,value为元组

        print(value)

        #3.进行解码,处理格式
        label_image = tf.decode_raw(value,tf.uint8)
        print(label_image)

        #处理格式,image,label
        #进行切片处理,标签值
        #tf.cast()函数是转换数据格式,此处是将label二进制数据转换成int32格式
        label = tf.cast(tf.slice(label_image,[0],[self.label_bytes]),tf.int32)

        #处理图片数据
        image = tf.slice(label_image,[self.label_bytes],[self.image_bytes])
        print(image)

        #处理图片的形状,提供给批处理
        #因为image的形状已经固定,此处形状用动态形状来改变
        image_tensor = tf.reshape(image,[self.height,self.width,self.channel])
        print(image_tensor)

        #批处理图片数据
        image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10)

        return image_batch,label_batch

if __name__ == '__main__':

    # 找到文件路径,名字,构造路径+文件名的列表,"A.csv"...
    # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表
    filename = os.listdir('./data/cifar10/cifar-10-batches-bin/')

    #加上路径
    file_list = [os.path.join('./data/cifar10/cifar-10-batches-bin/', file) for file in filename if file[-3:] == "bin"]

    #初始化参数
    cr = CifarRead(file_list)

    image_batch,label_batch = cr.read_and_decode()

    with tf.Session() as sess:
        #线程协调器
        coord = tf.train.Coordinator()

        #开启线程
        threads = tf.train.start_queue_runners(sess,coord=coord)

        print(sess.run([image_batch,label_batch]))

        #回收线程
        coord.request_stop()
        coord.join(threads)
  • TFRecords文件是TensorFlow中的统一格式,它的存储和读取方式较上面三种格式要稍微复杂一些。

会在

https://blog.csdn.net/m0_37407756/article/details/80671905

说明。