tensorflow使用range_input_producer多线程读取数据实例
先放关键代码:
i = tf.train.range_input_producer(num_expoches, num_epochs=1, shuffle=false).dequeue() inputs = tf.slice(array, [i * batch_size], [batch_size])
原理解析:
第一行会产生一个队列,队列包含0到num_expoches-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=false表示队列的元素是按0到num_expoches-1的顺序存储。在graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如num_expoches=3,如果num_epochs=2,则队列的内容是这样子;
0,1,2,0,1,2
队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。
如果num_epochs不指定,则队列内容是这样子:
0,1,2,0,1,2,0,1,2,0,1,2...
队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。
下面是完整的演示代码。
数据文件test.txt内容:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
main.py内容:
import tensorflow as tf import codecs batch_size = 6 num_expoches = 5 def input_producer(): array = codecs.open("test.txt").readlines() array = map(lambda line: line.strip(), array) i = tf.train.range_input_producer(num_expoches, num_epochs=1, shuffle=false).dequeue() inputs = tf.slice(array, [i * batch_size], [batch_size]) return inputs class inputs(object): def __init__(self): self.inputs = input_producer() def main(*args, **kwargs): inputs = inputs() init = tf.group(tf.initialize_all_variables(), tf.initialize_local_variables()) sess = tf.session() coord = tf.train.coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) sess.run(init) try: index = 0 while not coord.should_stop() and index<10: datalines = sess.run(inputs.inputs) index += 1 print("step: %d, batch data: %s" % (index, str(datalines))) except tf.errors.outofrangeerror: print("done traing:-------epoch limit reached") except keyboardinterrupt: print("keyboard interrput detected, stop training") finally: coord.request_stop() coord.join(threads) sess.close() del sess if __name__ == "__main__": main()
输出:
step: 1, batch data: ['1' '2' '3' '4' '5' '6'] step: 2, batch data: ['7' '8' '9' '10' '11' '12'] step: 3, batch data: ['13' '14' '15' '16' '17' '18'] step: 4, batch data: ['19' '20' '21' '22' '23' '24'] step: 5, batch data: ['25' '26' '27' '28' '29' '30'] done traing:-------epoch limit reached
如果range_input_producer去掉参数num_epochs=1,则输出:
step: 1, batch data: ['1' '2' '3' '4' '5' '6'] step: 2, batch data: ['7' '8' '9' '10' '11' '12'] step: 3, batch data: ['13' '14' '15' '16' '17' '18'] step: 4, batch data: ['19' '20' '21' '22' '23' '24'] step: 5, batch data: ['25' '26' '27' '28' '29' '30'] step: 6, batch data: ['1' '2' '3' '4' '5' '6'] step: 7, batch data: ['7' '8' '9' '10' '11' '12'] step: 8, batch data: ['13' '14' '15' '16' '17' '18'] step: 9, batch data: ['19' '20' '21' '22' '23' '24'] step: 10, batch data: ['25' '26' '27' '28' '29' '30']
有一点需要注意,文件总共有35条数据,batch_size = 6表示每个batch包含6条数据,num_expoches = 5表示产生5个batch,如果num_expoches =6,则总共需要36条数据,就会报如下错误:
invalidargumenterror (see above for traceback): expected size[0] in [0, 5], but got 6 [[node: slice = slice[index=dt_int32, t=dt_string, _device="/job:localhost/replica:0/task:0/cpu:0"](slice/input, slice/begin/_5, slice/size)]]
错误信息的意思是35/batch_size=5,即num_expoches 的取值能只能在0到5之间。
以上这篇tensorflow使用range_input_producer多线程读取数据实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
上一篇: 悟空你去化缘
下一篇: 基于Python获取照片的GPS位置信息
推荐阅读
-
使用Python脚本从文件读取数据代码实例
-
使用Tensorflow将自己的数据分割成batch训练实例
-
tensorflow使用range_input_producer多线程读取数据实例
-
Tensorflow中使用tfrecord方式读取数据的方法
-
使用Tensorflow将自己的数据分割成batch训练实例
-
TensorFlow 2.1.0 使用 TFRecord 转存与读取文本数据
-
tensorflow使用range_input_producer多线程读取数据实例
-
python使用numpy读取、保存txt数据的实例
-
使用python读取csv文件快速插入数据库的实例
-
使用Python脚本从文件读取数据代码实例