常用文件读取方法——tensorflow文件读取
tensorflow文件读取(多线程+队列)
重要的函数:数据切片 tf.slice()
数组装置 tf.transpose()
类型变换 tf.cast()
序列化 tostring()
1文件读取流程
1.1构造文件名队列
将文件名列表交给tf.train.string_input_producer 函数来生成一个先入先出的队列, 文件阅读器会需要它来读取数据
tf.train.string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, name=None)
string_tensor:含有文件名+路径的1阶张量
num_epochs:过几遍数据,默认过无限编
shuffle:是否打乱数据
Returns: 文件队列
file_queue = tf.train.string_input_producer(string_tensor,shuffle=True)
1.2.读取与解码
1.2.1.根据文件格式创建文件阅读器
文本/csv格式: tf.TextLineReader()
图像格式: tf.WholeFileReader()
二进制: tf.FixedLengthRecordReader()
TFRecords文件: tf.TFRecordReader()
1.2.2.文件读取——将文件名队列提供给阅读器的read
方法
阅读器的read
方法会输出一个key来表征输入的文件和其中的纪录(对于调试非常有用),同时得到一个字符串标量, 这个字符串标量可以被一个或多个解析器,或者转换操作将其解码为张量并且构造成为样本。
key, value =阅读器.read(queue, name=None)
-
key
: 文件名 -
value
: 样本
1.2.3.文件解码——根据文件格式选择不同的解码器
文本/csv格式: tf.decode_csv()
图像格式(不同的图片格式有不同的解码方式):
tf.image.decode_png()
tf.image.decode_jpeg()
二进制: tf.decode_raw()
TFRecords文件: tf.parse_single_example()
解码后默认格式:uint8
1.3.批处理——读取指定大小(个数)的张量——批处理前样本形状,类型必须一致
在数据输入管线的末端, 我们需要有另一个队列来执行输入样本的训练,评价和推理。因此我们使用f.train.batch、
tf.train.shuffle_batch
函数来对队列中的样本进行乱序处理
tf.train.batch(tensor_list, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, name=None)
-
tensor_list
: The list of tensors to enqueue. -
batch_size
: 从列表中读取的批处理器大小 -
num_threads
: 进入队列的线程数 -
capacity
: An integer. The maximum number of elements in the queue. -
enqueue_many
: Whether each tensor intensor_list
is a single example. -
shapes
: (Optional) The shapes for each example. Defaults to the inferred shapes fortensor_list
. -
name
: (Optional) A name for the operations.
A list of tensors with the same number and types as tensor_list
.
1.4. 线程操作
1.4.1开启线程协调器——对线程进行协调和管理
request_stop() 请求停止
should_stop() 询问是否结束
join(threads, stop_grace_period_secs=120)回收线程
Returns: 线程协调器实例
1.4.2收集图中所有队列线程,同时默认启动线程
tf.train.start_queue_runners(sess=None, coord=None, daemon=True, start=True, collection='queue_runners')
Args:
sess: 所在的会话
coord: 线程协调器
daemon: Whether the threads should be marked as daemons, meaning they don't block program exit.
start: Set to False to only create the threads, not start them.
collection: A GraphKey specifying the graph collection to get the queue runners from. Defaults
toGraphKeys.QUEUE_RUNNERS.
Returns:
A list of threads.
1.4.3回收线程
2图像读取案例
import tensorflow as tf
import os
def read_picture():
"""
读取狗图片案例
:return:
"""
# 1、构造文件名队列
# 构造文件名列表
filename_list = os.listdir("./dog")
# 给文件名加上路径
file_list = [os.path.join("./dog/", i) for i in filename_list]
# print("file_list:\n", file_list)
# print("filename_list:\n", filename_list)
file_queue = tf.train.string_input_producer(file_list)
# 2、读取与解码
# 读取
reader = tf.WholeFileReader()
key, value = reader.read(file_queue)
print("key:\n", key)
print("value:\n", value)
# 解码
image_decoded = tf.image.decode_jpeg(value)
print("image_decoded:\n", image_decoded)
# 将图片缩放到同一个大小
image_resized = tf.image.resize_images(image_decoded, [200, 200])
print("image_resized_before:\n", image_resized)
# 更新静态形状
image_resized.set_shape([200, 200, 3])
print("image_resized_after:\n", image_resized)
# 3、批处理队列
image_batch = tf.train.batch([image_resized], batch_size=100, num_threads=2, capacity=100)
print("image_batch:\n", image_batch)
# 开启会话
with tf.Session() as sess:
# 开启线程
# 构造线程协调器
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 运行
filename, sample, image, n_image = sess.run([key, value, image_resized, image_batch])
print("filename:\n", filename)
print("sample:\n", sample)
print("image:\n", image)
print("n_image:\n", n_image)
coord.request_stop()
coord.join(threads)
return None
if __name__ == "__main__":
# 代码1:读取狗图片案例
read_picture()
3二进制、TFRecords文件读取案例
import tensorflow as tf
import os
class Cifar():
def __init__(self):
# 设置图像大小
self.height = 32
self.width = 32
self.channel = 3
# 设置图像字节数
self.image = self.height * self.width * self.channel
self.label = 1
self.sample = self.image + self.label
def read_binary(self):
"""
读取二进制文件
:return:
"""
# 1、构造文件名队列
filename_list = os.listdir("./cifar-10-batches-bin")
# print("filename_list:\n", filename_list)
file_list = [os.path.join("./cifar-10-batches-bin/", i) for i in filename_list if i[-3:]=="bin"]
# print("file_list:\n", file_list)
file_queue = tf.train.string_input_producer(file_list)
# 2、读取与解码
# 读取
reader = tf.FixedLengthRecordReader(self.sample)
# key文件名 value样本
key, value = reader.read(file_queue)
# 解码
image_decoded = tf.decode_raw(value, tf.uint8)
print("image_decoded:\n", image_decoded)
# 切片操作
label = tf.slice(image_decoded, [0], [self.label])
image = tf.slice(image_decoded, [self.label], [self.image])
print("label:\n", label)
print("image:\n", image)
# 调整图像的形状
image_reshaped = tf.reshape(image, [self.channel, self.height, self.width])
print("image_reshaped:\n", image_reshaped)
# 三维数组的转置
image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
print("image_transposed:\n", image_transposed)
# 3、构造批处理队列
image_batch, label_batch = tf.train.batch([image_transposed, label], batch_size=100, num_threads=2, capacity=100)
# 开启会话
with tf.Session() as sess:
# 开启线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
label_value, image_value = sess.run([label_batch, image_batch])
print("label_value:\n", label_value)
print("image:\n", image_value)
coord.request_stop()
coord.join(threads)
return image_value, label_value
def write_to_tfrecords(self, image_batch, label_batch):
"""
将样本的特征值和目标值一起写入tfrecords文件
:param image:
:param label:
:return:
"""
with tf.python_io.TFRecordWriter("cifar10.tfrecords") as writer:
# 循环构造example对象,并序列化写入文件
for i in range(100):
image = image_batch[i].tostring()
label = label_batch[i][0]
# print("tfrecords_image:\n", image)
# print("tfrecords_label:\n", label)
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
}))
# example.SerializeToString()
# 将序列化后的example写入文件
writer.write(example.SerializeToString())
return None
def read_tfrecords(self):
"""
读取TFRecords文件
:return:
"""
# 1、构造文件名队列
file_queue = tf.train.string_input_producer(["cifar10.tfrecords"])
# 2、读取与解码
# 读取
reader = tf.TFRecordReader()
key, value = reader.read(file_queue)
# 解析example
feature = tf.parse_single_example(value, features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.int64)
})
image = feature["image"]
label = feature["label"]
print("read_tf_image:\n", image)
print("read_tf_label:\n", label)
# 解码
image_decoded = tf.decode_raw(image, tf.uint8)
print("image_decoded:\n", image_decoded)
# 图像形状调整
image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channel])
print("image_reshaped:\n", image_reshaped)
# 3、构造批处理队列
image_batch, label_batch = tf.train.batch([image_reshaped, label], batch_size=100, num_threads=2, capacity=100)
print("image_batch:\n", image_batch)
print("label_batch:\n", label_batch)
# 开启会话
with tf.Session() as sess:
# 开启线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
image_value, label_value = sess.run([image_batch, label_batch])
print("image_value:\n", image_value)
print("label_value:\n", label_value)
# 回收资源
coord.request_stop()
coord.join(threads)
return None
if __name__ == "__main__":
cifar = Cifar()
# image_value, label_value = cifar.read_binary()
# cifar.write_to_tfrecords(image_value, label_value)
cifar.read_tfrecords()
4TFRecords文件
train.Example协议内存块文件定义了将数据进行序列化的格式
message Example{
Features features = 1;
};
message Features{
map<string, Feature> feature = 1; #根据属性名获取属性值的字典
};
meaaage Feature{
oneof kind{
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
1.4.1写入TFRecords文件步骤:
1)构造文件读写器: writer=tf.python_io.TFRecordWriter(path)
2)构造train.Example对象
注意:Feature仅支持bytes_list,float_list,int64_list 三种对象类型,所以图片类型需要通过tostring序列化
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
}))
3)序列化train.Example对象,并写入文件
writer.write(example.SerializeToString())
1.4.2读取TFRecords文件步骤:
解析单一Example对象:tf.parse_single_example(serialized,features,name,example_names)
解析Features对象:tf.FixedLenFeature(shape,dtype)
1)构造文件名队列
2)读取和解码
读取
解析example
feature = tf.parse_single_example(value, features={
"image":tf.FixedLenFeature([], tf.string),
"label":tf.FixedLenFeature([], tf.int64)
})
image = feature["image"]
label = feature["label"]
解码
tf.decode_raw()
3)构造批处理队列
上一篇: 初识hibernate
推荐阅读
-
常用文件读取方法——tensorflow文件读取
-
前端 读取XML文件
-
Springboot读取自定义properties文件和application.properties文件的值
-
springboot 读取自定义.properties 文件的内容
-
[iOS Tips]读取文件的头8字节和尾8字节的十六进制
-
SpringBoot 2.x 从yml文件中读取配置自动解密,同时附上DESUtil
-
php删除txt文件指定行及按行读取txt文档数据的方法
-
PHP遍历目录文件的常用方法小结
-
php实现的读取CSV文件函数示例
-
JAR包读取jar包内部和外部的配置文件,springboot读取外部配置文件的方法(优先级配置)