TensorFlow输入数据处理框架
程序员文章站
2022-07-13 14:50:57
...
如图,大致为输入数据处理流程示意图。输入数据处理第一步为获取存储训练数据的文件列表,在该图中文件列表为{A,B,C}。通过tf.train.string_input_producer函数可以选择性将文件顺序打乱,并加入输入队列。tf.train.string_input_producer函数会生成并维护一个输入文件队列,不同线程中的文件读取函数可以共享这个文件队列。
在读取样例程序后,需要对图像进行预处理。预处理的过程也会通过tf.train.shuffle_batch提供的机制并行的跑在多个线程中。输入数据处理流程的最后通过tf.train.shuffle_batch函数将处理好的单个输入样例整理成batch提供给神经网络输入层。
import tensorflow as tf
#创建文件列表
files = tf.train.match_filenames_once("Records/output.tfrecords")
#创建文件输入队列
filename_queue = tf.train.string_input_producer(files, shuffle=False)
# 读取文件。
# 解析数据。假设image是图像数据,label是标签,height、width、channels给出了图片的维度
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)
# 解析读取的样例。
features = tf.parse_single_example(
serialized_example,
features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'channels': tf.FixedLenFeature([], tf.int64)
})
image, label = features['image'], features['label']
height, width = tf.cast(features['height'], tf.int32), tf.cast(features['width'], tf.int32)
channels = tf.cast(features['channels'], tf.int32)
# 从原始图像中解析出像素矩阵,并根据像素尺寸还原图像
decoded_images = tf.decode_raw(features['image_raw'],tf.uint8)
decoded_image.set_shape([height, width, channels])
#定义神经网络输入层图片的大小
image_size = 299
# preprocess_for_train函数是对图片进行预处理的函数
distorted_image = preprocess_for_train(decoded_image, image_size, image_size,
None)
#将处理后的图像和标签通过tf.train.shuffle_batch整理成神经网络训练时需要的batch
min_after_dequeue = 10000
batch_size = 100
capacity = min_after_dequeue + 3 * batch_size
image_batch, label_batch = tf.train.shuffle_batch([images, labels],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue)
# 定义神经网络的结构及优化过程。image_batch可以作为输入提供给神经网络的输入层
#label_batch则提供了输入batch中样例的正确答案
logit = inference(image_batch)
loss = calc_loss(logit, label_batch)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
#声明会话并运行神经网络优化过程
with tf.Session() as sess:
#神经网络训练准备工作,这些工作包括变量初始化、线程启动
sess.run(
[tf.global_variables_initializer(),
tf.local_variables_initializer()])
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
# 神经网络训练过程
for i in range(TRAINING_ROUNDS):
sess.run(train_step)
#停止所有线程
coord.request_stop()
coord.join()
其代码如下:
上一篇: Java 8——函数式数据处理(流)
下一篇: JUC之并发容器的选择
推荐阅读
-
TensorFlow数据输入的方法示例
-
Tensorflow使用tfrecord输入数据格式
-
TensorFlow多线程输入数据处理框架(四)——输入数据处理框架
-
TensorFlow实战Google深度学习框架-人工智能教程-自学人工智能的第二天-深度学习
-
手写数字识别(使用tensorflow2.2.0框架)
-
使用tensorflow DataSet实现高效加载变长文本输入
-
python生成tensorflow输入输出的图像格式的方法
-
Tensorflow 利用tf.contrib.learn建立输入函数的方法
-
编辑表格输入内容、根据input输入框输入数字动态生成表格行数、编辑表格内容提交传给后台数据处理
-
基于EasyUI的TopJUI前端框架之如何动态改变下拉列表框ComboBox输入框的背景颜色