TensorFlow:四种类型数据的读取流程及API讲解和代码实现
由于数据量一般会比较大,所以更多的会使用直接从文件中读取。
但是对于不同的文件类型,需要不同的文件处理API,有时候比较容易弄混淆,接下来就来梳理一下。
一.文件读取流程
如上图所示,展示了文件读取的大致流程。
最左边的A、B、C是存储于磁盘中文件,经过打乱文件之后(这里是默认的乱序读取,只是文件的顺序乱,但是文件内容不受影响),进入到文件队列中(Filename Queue)。文件队列当中的文件经过阅读器(Reader)处理,存储到内存当中。接下来对文件进行解码(Decode),解码之后进入样本队列当中进行批处理,此时经过批处理之后就可以用于模型训练了。
现在举例,对于读取CSV文件,大致要经历一下几步:
1. 找到文件,并构造文件的列表(一阶张量)
2. 构造文件队列
3. 读取文件内容
4. 解码CSV并读取内容
5. 开启会话运行,得出训练结果
二.文件读取的API
1.文件队列构造
tf.train.string_input_producer(string_tensor,num_epochs,shuffle=True)
- 将输出字符串(例如文件名)输入到管道队列
-
string_tensor
:含有文件名的一阶张量,需要指定文件路径 -
num_epochs
:将全部数据循环的次数 -
return
:具有输出字符串的队列
2.文件阅读器
此时需要根据文件的格式,选择对应的文件阅读器
(1) 文本文件:tf.TextLineReader()
- 读取文本文件,逗号分隔值(CSV)格式,默认按行读取
- return:读取器实例
(2)二进制文件:tf.FixedLengthRecordReader(record_bytes)
- 读取每个记录是固定数量字节的二进制文件
- record_bytes:整型,指定每次读取的字节数
- return:读取器实例
(3)图片文件:tf.WholeReader()
- 将文件的全部内容作为值输出,即一次读取一整个文件
- return:读取器实例
(4)TFRecords文件:tf.TFRecordReader()
- 读取 TFRecords文件
- return:读取器实例
注:这几种文件格式都有一个共同的读取方法:read(file_queue)
- 从队列中指定内容数量
- file_name : 文件队列
- ruturn : 返回一个Tensor元组(key,value)
- key : 文件名
- value : 每次读取的值(一行文本、一张图片或指定字节的值)
3.文件内容解码器
由于从文件中读取的是字符串,需要函数去解析这些字符串,最后变换成张量
(1)CSV文件: tf.decode_csv(records,record_defaults=None,field_delim=None,name=None)
- 将CSV文件转换成张量,需要
tf.TextLineReader()
搭配使用 - records : tensor型字符串,每个字符串是CSV中的记录行(即value值)
- record_defaults : 此参数决定了所得张量的类型,并设置一个值,如果在输入字符串中缺少则使用默认值,如[[1],[1]] 或者[[“None”],[“None”]]
- field_dim : 默认分隔符“ ,”
(2)二进制文件: tf.decode_raw(bytes,out_type,little_endian=None,name=None)
- 将字节转换为一个数字向量表示,字节为以字符串类型的张量
- 与函数tf.FixedLengthRecordReader搭配使用
- 将二进制转换为uint8格式
(3)图像文件:
-
1)
tf.image.decode_jpeg(contens)
- 将JPEG编码的图像解码为uint8张量
- return : uint8张量,3-D形状[height,width,channels]
-
2)
tf.image.decode_png(contents)
- 将PNG编码的图像解码为uint8或者uint16编码
- return : 张量类型,3-D形状[height,width,channels]
(4)TFRecords文件:
TFRecords文件是TensorFlow中的统一格式,它的存储和读取方式较为复杂,我会在下篇文章中单独来梳理这部分的内容。
4.批处理数据
对数据进行批处理需要在会话开启之前进行
(1)tf.train.batch(tensors,batch_size,num_threads=1,capacity=32,name=None)
- 读取指定大小(个数)的张量
- tensor : 包含张量的列表
- batch_size : 从队列中读取的批处理数据大小
- num_threads : 进入队列的线程数
- capacity : 整数,批处理队列中元素的最大数量
- teturn : tensors
(2)tf.train.shuffle_batch(tensors,batch_size,capacity,min_after_dequeue,num_threads=1,capacity=32,name=None)
- 乱序读取指定大小(数量)的张量
- min_after_dequeue : 留下队列里的张量个数,能够保持随机打乱
三.示例代码
1.CSV文件读取案例
def csvread(filelist):
"""
CSV文件读取
:param filelist: 文件的列表(1阶张量)
:return:None
"""
#2.构造文件的队列
file_queue = tf.train.string_input_producer(filelist)
#3.读取文件内容tf.decode_csv()
#构造阅读器
reader = tf.TextLineReader()
#读队列文件内容,一行
key,value = reader.read(file_queue)
#4、解码csv文件
#指定每一行格式的默认值,类型,[[1],[2.0],[1]]
records = [["None"],["None"]]
example,label = tf.decode_csv(value,record_defaults=records)
#批处理读取数据
example_batch,label_batch = tf.train.batch([example,label],batch_size=20,num_threads=1,capacity=100)
#5、会话运行结果
with tf.Session() as sess:
#开启线程协调器
coord = tf.train.Coordinator()
#创建子线程去进行操作,返回线程列表
threads = tf.train.start_queue_runners(sess,coord = coord)
#打印
print(sess.run([example_batch,label_batch]))
#回收
coord.request_stop() #强制请求线程停止
coord.join(threads) #等待线程终止回收
return None
if __name__ == '__main__':
#列出文件目录,构造路径+文件名的列表,"A.csv"...
# os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表
filename = os.listdir('./data/csvdata')
#加上路径
file_list = [os.path.join('./data/csvdata', file) for file in filename]
csvread(file_list)
2,图像文件读取案例
./data/dog文件中存储了100张 *.jpg格式的狗的图片
def picread(file_list):
"""
读取狗图片并转换成张量
:param file_list:
:return:
"""
#1、构造文件的队列
file_queue = tf.train.string_input_producer(file_list)
#2、生成图片读取器,读取队列内容
reader = tf.WholeFileReader() #返回读取器实例
key ,value = reader.read(file_queue)
print(key,value)
#3.进行图片的解码
image = tf.image.decode_jpeg(value)
print(image)
#4.处理图片的大小
image_resize = tf.image.resize_images(image,[256,256])
print(image_resize)
#设置静态形状 ,动态形状也可以
image_resize.set_shape([256,256,3])
print(image_resize)
#5.进行批处理 #此处image_siez必须指定形状,而且要为列表
image_batch = tf.train.batch([image_resize],batch_size=100,num_threads=1,capacity=100)
print(image_batch)
return image_batch
if __name__ == '__main__':
# 找到文件路径,名字,构造路径+文件名的列表,"A.csv"...
# os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表
filename = os.listdir('./data/dog')
#加上路径
file_list = [os.path.join('./data/dog', file) for file in filename]
image_batch = picread(file_list)
with tf.Session() as sess:
#定义线程协调器
coord = tf.train.Coordinator()
#开启线程
threads = tf.train.start_queue_runners(sess,coord=coord)
print(sess.run(image_batch))
#回收线程
coord.request_stop()
coord.join(threads)
3.二进制文件读取案例
此案例中数据是使用的下载好的二进制的cifar10数据
#读取二进制转换文件
class CifarRead(object):
"""
读取二进制文件转换成张量,写进TFRecords,同时读取TFRcords
"""
def __init__(self,file_list):
"""
初始化图片参数
:param file_list:图片的路径名称列表
"""
#文件列表
self.file_list = file_list
#图片大小,二进制文件字节数
self.height = 32
self.width = 32
self.channel = 3
self.label_bytes = 1
self.image_bytes = self.height * self.width * self.channel
self.bytes = self.label_bytes + self.image_bytes
def read_and_decode(self):
"""
解析二进制文件到张量
:return: 批处理的image,label张量
"""
#1.构造文件队列
file_queue = tf.train.string_input_producer(self.file_list)
#2.阅读器读取内容
reader = tf.FixedLengthRecordReader(self.bytes)
key ,value = reader.read(file_queue) #key为文件名,value为元组
print(value)
#3.进行解码,处理格式
label_image = tf.decode_raw(value,tf.uint8)
print(label_image)
#处理格式,image,label
#进行切片处理,标签值
#tf.cast()函数是转换数据格式,此处是将label二进制数据转换成int32格式
label = tf.cast(tf.slice(label_image,[0],[self.label_bytes]),tf.int32)
#处理图片数据
image = tf.slice(label_image,[self.label_bytes],[self.image_bytes])
print(image)
#处理图片的形状,提供给批处理
#因为image的形状已经固定,此处形状用动态形状来改变
image_tensor = tf.reshape(image,[self.height,self.width,self.channel])
print(image_tensor)
#批处理图片数据
image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10)
return image_batch,label_batch
if __name__ == '__main__':
# 找到文件路径,名字,构造路径+文件名的列表,"A.csv"...
# os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表
filename = os.listdir('./data/cifar10/cifar-10-batches-bin/')
#加上路径
file_list = [os.path.join('./data/cifar10/cifar-10-batches-bin/', file) for file in filename if file[-3:] == "bin"]
#初始化参数
cr = CifarRead(file_list)
image_batch,label_batch = cr.read_and_decode()
with tf.Session() as sess:
#线程协调器
coord = tf.train.Coordinator()
#开启线程
threads = tf.train.start_queue_runners(sess,coord=coord)
print(sess.run([image_batch,label_batch]))
#回收线程
coord.request_stop()
coord.join(threads)
- TFRecords文件是TensorFlow中的统一格式,它的存储和读取方式较上面三种格式要稍微复杂一些。
会在
https://blog.csdn.net/m0_37407756/article/details/80671905
说明。
上一篇: JavaScript小技巧:!!的使用
下一篇: DFS遍历图时的小技巧