
Common File-Reading Methods: TensorFlow File Reading

程序员文章站 (Programmer Article Site) 2024-03-14 22:00:23

TensorFlow file reading (multithreading + queues)

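Important functions:

  • Data slicing: tf.slice()
  • Array transpose: tf.transpose()
  • Type conversion: tf.cast()
  • Serialization: ndarray.tostring() (named tobytes() in newer NumPy)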


1 File-reading workflow

1.1 Build the filename queue

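Pass the list of file names to tf.train.string_input_producer, which builds a first-in, first-out (FIFO) queue of file names; the file reader dequeues from this queue to read the data.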
tf.train.string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, name=None)

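  string_tensor: a rank-1 tensor of file names (including their paths)

  num_epochs: how many passes to make over the data; the default None cycles through the file names indefinitely. If you set it, run tf.local_variables_initializer() before starting the queue runners, because the epoch counter is stored in a local variable.

  shuffle: whether to shuffle the file names within each epoch

  capacity: the maximum number of file names held in the queue

Returns: the filename queue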

file_queue = tf.train.string_input_producer(string_tensor,shuffle=True)
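To make the num_epochs and shuffle semantics concrete, here is a pure-Python sketch (not TensorFlow code) of the order in which file names leave the queue: each epoch enqueues every file name once, reshuffled per epoch when shuffle=True, and names are dequeued first-in, first-out. The function name build_filename_queue and the file names are hypothetical. Unlike the real op, this sketch needs a finite num_epochs and fills the whole queue up front instead of from a background thread with a bounded capacity.

```python
import random
from collections import deque


def build_filename_queue(file_list, num_epochs=1, shuffle=True, seed=None):
    """Hypothetical pure-Python stand-in for the file-name order produced by
    tf.train.string_input_producer. num_epochs must be finite here."""
    rng = random.Random(seed)
    queue = deque()
    for _ in range(num_epochs):
        epoch = list(file_list)
        if shuffle:
            rng.shuffle(epoch)  # reshuffle independently for every epoch
        queue.extend(epoch)  # enqueue at the back
    return queue


files = ["./dog/a.jpg", "./dog/b.jpg", "./dog/c.jpg"]
q = build_filename_queue(files, num_epochs=2, shuffle=True, seed=0)
order = [q.popleft() for _ in range(len(q))]  # dequeue from the front (FIFO)
print(order)

in_order = list(build_filename_queue(files, num_epochs=1, shuffle=False))
print(in_order)
```

Each half of order is a permutation of files, which is why shuffling still lets every file be read exactly once per epoch.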

1.2 Reading and decoding

1.2.1 Create a file reader for the file format

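       Text/CSV:     tf.TextLineReader() (one line per record)

       Images:       tf.WholeFileReader() (the whole file as one record)

       Binary:       tf.FixedLengthRecordReader() (a fixed number of bytes per record, passed as record_bytes)

       TFRecords:    tf.TFRecordReader()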

1.2.2 Reading: pass the filename queue to the reader's read method

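The reader's read method outputs a key that identifies the input file and the record within it (very useful for debugging), together with a scalar string value. One or more parsers or conversion ops then decode this string into tensors and assemble them into a sample.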

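  key, value = reader.read(queue, name=None)

  • key: identifies the file (and, for line- or record-based readers, the record within it)
  • value: one sample, as a raw string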
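A pure-Python analogue (not the TensorFlow implementation) of what a fixed-length record reader hands back may help: each value is exactly record_bytes raw bytes, and the key names the file and the record. The function name read_fixed_length_records, the "filename:index" key format, and the toy file are assumptions made for illustration.

```python
import os
import tempfile


def read_fixed_length_records(filename, record_bytes):
    """Yield (key, value) pairs from a file of fixed-size records.
    The "filename:index" key format is an illustrative assumption."""
    with open(filename, "rb") as f:
        index = 0
        while True:
            chunk = f.read(record_bytes)
            if len(chunk) < record_bytes:  # this sketch stops at a trailing partial record
                break
            yield f"{filename}:{index}", chunk
            index += 1


# Build a toy file holding three 4-byte records
path = os.path.join(tempfile.mkdtemp(), "toy.bin")
with open(path, "wb") as f:
    f.write(bytes(range(12)))

records = list(read_fixed_length_records(path, 4))
for key, value in records:
    print(key, list(value))
```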

1.2.3 Decoding: choose a decoder for the file format

       Text/CSV:     tf.decode_csv()

       Images (each image format has its own decoder):

                     tf.image.decode_png()

                     tf.image.decode_jpeg()

       Binary:       tf.decode_raw()

       TFRecords:    tf.parse_single_example()

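Decoded dtype: the image decoders output uint8 by default, and tf.decode_raw outputs whatever out_type you pass (tf.uint8 in the examples below). Use tf.cast() to convert, for example to tf.float32 before training.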
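Of the decoders listed above, tf.decode_csv is the one driven by record_defaults, a list with one default per column (for example [[0], [0.0], [""]]): each default fixes the column's type and fills the column when a field is empty. The pure-Python function decode_csv_line below (a hypothetical name, not a TensorFlow API) sketches that behaviour for a single line with the standard csv module; for simplicity it takes the defaults as plain values rather than one-element lists.

```python
import csv


def decode_csv_line(line, record_defaults):
    """Parse one CSV line into typed values.

    Each default decides the column's type (int, float or str) and is used
    when the field is empty, mirroring the role of record_defaults."""
    fields = next(csv.reader([line]))
    if len(fields) != len(record_defaults):
        raise ValueError(
            f"expected {len(record_defaults)} fields, got {len(fields)}")
    return [default if field == "" else type(default)(field)
            for field, default in zip(fields, record_defaults)]


row = decode_csv_line("3,,cat", [0, 0.5, ""])
print(row)
```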

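1.3 Batching: read a batch of a given size (number of samples); all samples must have the same shape and dtype before batching

At the end of the input pipeline we need another queue that collects samples into batches for training, evaluation, and inference. tf.train.batch does this in FIFO order; tf.train.shuffle_batch additionally shuffles the samples held in its queue.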

tf.train.batch(tensor_list, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, name=None)

Args:

  • tensor_list: The list of tensors to enqueue.
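  • batch_size: the number of samples in each dequeued batch.
  • num_threads: the number of threads enqueuing tensor_list.
  • capacity: An integer. The maximum number of elements in the queue.
  • enqueue_many: If False (the default), each tensor in tensor_list is a single example; if True, the first dimension of each tensor indexes a batch of examples.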
  • shapes: (Optional) The shapes for each example. Defaults to the inferred shapes for tensor_list.
  • name: (Optional) A name for the operations.

Returns:

    A list of tensors with the same number and types as tensor_list.
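The requirement that every sample share a shape before batching can be seen in a small pure-Python sketch (not TensorFlow's implementation): samples arrive one at a time, their shapes are checked, and each group of batch_size samples is stacked into one batch. The helpers shape_of and batch_samples are hypothetical; this sketch drops a trailing partial batch and checks only shape, since Python lists carry no dtype.

```python
def shape_of(x):
    """Shape of a rectangular nested list, e.g. [[1, 2], [3, 4]] -> (2, 2).
    Only the first element along each axis is inspected."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)


def batch_samples(samples, batch_size):
    """Stack consecutive samples into lists of batch_size, in arrival order."""
    batches, current, expected = [], [], None
    for sample in samples:
        shape = shape_of(sample)
        if expected is None:
            expected = shape
        elif shape != expected:
            raise ValueError(f"sample shape {shape} != {expected}")
        current.append(sample)
        if len(current) == batch_size:
            batches.append(current)
            current = []
    return batches  # a trailing partial batch is dropped in this sketch


samples = [[i, i + 1] for i in range(5)]  # five samples of shape (2,)
batches = batch_samples(samples, batch_size=2)
print(batches)
```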

1.4 Thread operations

1.4.1 Create a thread coordinator to coordinate and manage the threads

      class tf.train.Coordinator()

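  request_stop()   request that the threads stop

  should_stop()    check whether a stop has been requested

  join(threads, stop_grace_period_secs=120)   wait for the threads to terminate

     Returns: a Coordinator instance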

1.4.2 Collect all queue runners in the graph and start their threads (started by default)

tf.train.start_queue_runners(sess=None, coord=None, daemon=True, start=True, collection='queue_runners')

Args:

    sess: the session in which to run the queue ops
    coord: the thread coordinator
    daemon: Whether the threads should be marked as daemons, meaning they don't block program exit.
    start: Set to False to only create the threads, not start them.
    collection: A GraphKey specifying the graph collection to get the queue runners from. Defaults to GraphKeys.QUEUE_RUNNERS.
Returns:

A list of threads.

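1.4.3 Reclaim the threads

When you are done, call coord.request_stop() to ask the threads to stop, then coord.join(threads) to wait for them to finish.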
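The coordinator pattern can be sketched with Python's standard threading and queue modules. This is an analogy, not TensorFlow code: the Coordinator class below is a hypothetical stand-in exposing the same three methods, and the producer function plays the role of a queue-runner thread that keeps a bounded queue filled until a stop is requested.

```python
import queue
import threading


class Coordinator:
    """Hypothetical stand-in for tf.train.Coordinator: a shared stop flag."""

    def __init__(self):
        self._stop = threading.Event()

    def request_stop(self):
        self._stop.set()

    def should_stop(self):
        return self._stop.is_set()

    def join(self, threads, timeout=120):
        # Simplification: just waits up to `timeout` seconds for each thread
        for t in threads:
            t.join(timeout)


def producer(coord, q, worker_id):
    """Acts like a queue runner: enqueue samples until a stop is requested."""
    n = 0
    while not coord.should_stop():
        try:
            q.put((worker_id, n), timeout=0.1)  # time out so the flag is re-checked
            n += 1
        except queue.Full:
            continue


coord = Coordinator()
q = queue.Queue(maxsize=32)  # bounded, like capacity=32
threads = [threading.Thread(target=producer, args=(coord, q, i), daemon=True)
           for i in range(2)]
for t in threads:
    t.start()

batch = [q.get() for _ in range(10)]  # the main thread consumes one "batch"
coord.request_stop()
coord.join(threads)
print(len(batch), all(not t.is_alive() for t in threads))
```

The producers block on a full queue, which is why the stop flag is polled between bounded waits rather than only after each successful put.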

2 Image reading example

import os

import tensorflow as tf  # TensorFlow 1.x API


def read_picture():
    """
    Example: read the dog pictures
    :return: None
    """
    # 1. Build the filename queue
    # List the file names in ./dog
    filename_list = os.listdir("./dog")
    # Prepend the directory to each file name
    file_list = [os.path.join("./dog/", i) for i in filename_list]
    # print("file_list:\n", file_list)
    # print("filename_list:\n", filename_list)
    file_queue = tf.train.string_input_producer(file_list)

    # 2. Read and decode
    # Read: key is the file name, value is the whole file's contents
    reader = tf.WholeFileReader()
    key, value = reader.read(file_queue)
    print("key:\n", key)
    print("value:\n", value)

    # Decode the JPEG bytes into a uint8 image tensor
    image_decoded = tf.image.decode_jpeg(value)
    print("image_decoded:\n", image_decoded)

    # Resize every image to the same size
    image_resized = tf.image.resize_images(image_decoded, [200, 200])
    print("image_resized_before:\n", image_resized)
    # Update the static shape (the channel count is unknown after decode_jpeg)
    image_resized.set_shape([200, 200, 3])
    print("image_resized_after:\n", image_resized)

    # 3. Batch queue
    image_batch = tf.train.batch([image_resized], batch_size=100, num_threads=2, capacity=100)
    print("image_batch:\n", image_batch)

    # Open a session
    with tf.Session() as sess:
        # Start the threads
        # Create the thread coordinator
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Run
        filename, sample, image, n_image = sess.run([key, value, image_resized, image_batch])
        print("filename:\n", filename)
        print("sample:\n", sample)
        print("image:\n", image)
        print("n_image:\n", n_image)

        # Stop and reclaim the threads
        coord.request_stop()
        coord.join(threads)

    return None


if __name__ == "__main__":
    # Code 1: read the dog pictures
    read_picture()

3 Binary and TFRecords file reading example (CIFAR-10)

import os

import tensorflow as tf  # TensorFlow 1.x API


class Cifar:

    def __init__(self):

        # Image size
        self.height = 32
        self.width = 32
        self.channel = 3

        # Bytes per image, per label, and per sample (one record)
        self.image = self.height * self.width * self.channel
        self.label = 1
        self.sample = self.image + self.label


    def read_binary(self):
        """
        Read the CIFAR-10 binary files
        :return: one batch of images and labels as NumPy arrays
        """
        # 1. Build the filename queue
        filename_list = os.listdir("./cifar-10-batches-bin")
        # print("filename_list:\n", filename_list)
        # Keep only the .bin data files
        file_list = [os.path.join("./cifar-10-batches-bin/", i) for i in filename_list if i.endswith(".bin")]
        # print("file_list:\n", file_list)
        file_queue = tf.train.string_input_producer(file_list)

        # 2. Read and decode
        # Read: every record is exactly self.sample bytes
        reader = tf.FixedLengthRecordReader(self.sample)
        # key: file name (and record), value: one sample
        key, value = reader.read(file_queue)

        # Decode the raw bytes into a uint8 vector
        image_decoded = tf.decode_raw(value, tf.uint8)
        print("image_decoded:\n", image_decoded)

        # Slice: the first byte is the label, the remaining bytes are the image
        label = tf.slice(image_decoded, [0], [self.label])
        image = tf.slice(image_decoded, [self.label], [self.image])
        print("label:\n", label)
        print("image:\n", image)

        # Reshape the image bytes to [channel, height, width]
        image_reshaped = tf.reshape(image, [self.channel, self.height, self.width])
        print("image_reshaped:\n", image_reshaped)

        # Transpose the 3-D array to [height, width, channel]
        image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
        print("image_transposed:\n", image_transposed)

        # 3. Build the batch queue
        image_batch, label_batch = tf.train.batch([image_transposed, label], batch_size=100, num_threads=2, capacity=100)

        # Open a session
        with tf.Session() as sess:

            # Start the threads
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            label_value, image_value = sess.run([label_batch, image_batch])
            print("label_value:\n", label_value)
            print("image:\n", image_value)

            # Stop and reclaim the threads
            coord.request_stop()
            coord.join(threads)

        return image_value, label_value

    def write_to_tfrecords(self, image_batch, label_batch):
        """
        Write the features (images) and targets (labels) of the samples together to a TFRecords file
        :param image_batch: uint8 image array of shape [batch, height, width, channel]
        :param label_batch: label array of shape [batch, 1]
        :return: None
        """
        with tf.python_io.TFRecordWriter("cifar10.tfrecords") as writer:
            # Build one Example per sample, serialize it, and write it to the file
            for i in range(len(image_batch)):
                # Serialize the image array to raw bytes (tostring() in older NumPy)
                image = image_batch[i].tobytes()
                # Int64List expects Python ints
                label = int(label_batch[i][0])
                # print("tfrecords_image:\n", image)
                # print("tfrecords_label:\n", label)
                example = tf.train.Example(features=tf.train.Features(feature={
                    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                }))
                # Write the serialized Example to the file
                writer.write(example.SerializeToString())

        return None

    def read_tfrecords(self):
        """
        Read the TFRecords file
        :return: None
        """
        # 1. Build the filename queue
        file_queue = tf.train.string_input_producer(["cifar10.tfrecords"])

        # 2. Read and decode
        # Read
        reader = tf.TFRecordReader()
        key, value = reader.read(file_queue)

        # Parse the Example
        feature = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64)
        })
        image = feature["image"]
        label = feature["label"]
        print("read_tf_image:\n", image)
        print("read_tf_label:\n", label)

        # Decode the image bytes back into a uint8 vector
        image_decoded = tf.decode_raw(image, tf.uint8)
        print("image_decoded:\n", image_decoded)
        # Restore the image shape (the bytes carry no shape information)
        image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channel])
        print("image_reshaped:\n", image_reshaped)

        # 3. Build the batch queue
        image_batch, label_batch = tf.train.batch([image_reshaped, label], batch_size=100, num_threads=2, capacity=100)
        print("image_batch:\n", image_batch)
        print("label_batch:\n", label_batch)

        # Open a session
        with tf.Session() as sess:

            # Start the threads
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            image_value, label_value = sess.run([image_batch, label_batch])
            print("image_value:\n", image_value)
            print("label_value:\n", label_value)

            # Release resources
            coord.request_stop()
            coord.join(threads)

        return None


if __name__ == "__main__":
    cifar = Cifar()
    # Run these two lines first to create cifar10.tfrecords:
    # image_value, label_value = cifar.read_binary()
    # cifar.write_to_tfrecords(image_value, label_value)
    cifar.read_tfrecords()
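The byte layout that read_binary relies on can be checked without TensorFlow. In the CIFAR-10 binary format, each 3073-byte record is one label byte followed by 1024 red, 1024 green and 1024 blue bytes, each plane stored row by row; that is why the code slices off the first byte, reshapes to [channel, height, width] and transposes with [1, 2, 0]. The pure-Python helper parse_cifar_record below is hypothetical and performs the same three steps on a synthetic record.

```python
HEIGHT, WIDTH, CHANNEL = 32, 32, 3
LABEL_BYTES = 1
IMAGE_BYTES = HEIGHT * WIDTH * CHANNEL  # 3072
RECORD_BYTES = LABEL_BYTES + IMAGE_BYTES  # 3073


def parse_cifar_record(record):
    """Split one record into a label and an image in [height][width][channel] order."""
    if len(record) != RECORD_BYTES:
        raise ValueError(f"expected {RECORD_BYTES} bytes, got {len(record)}")
    values = list(record)  # bytes -> ints in 0..255, like tf.decode_raw(value, tf.uint8)
    label = values[:LABEL_BYTES]  # like tf.slice(decoded, [0], [1])
    image = values[LABEL_BYTES:]  # like tf.slice(decoded, [1], [3072])
    # like tf.reshape(image, [CHANNEL, HEIGHT, WIDTH])
    chw = [[image[(c * HEIGHT + h) * WIDTH:(c * HEIGHT + h + 1) * WIDTH]
            for h in range(HEIGHT)]
           for c in range(CHANNEL)]
    # like tf.transpose(chw, [1, 2, 0]): output[h][w][c] = chw[c][h][w]
    hwc = [[[chw[c][h][w] for c in range(CHANNEL)] for w in range(WIDTH)]
           for h in range(HEIGHT)]
    return label, hwc


# Synthetic record: label 3; red = row index, green = column index, blue = 0
red = bytes(h for h in range(HEIGHT) for _ in range(WIDTH))
green = bytes(w for _ in range(HEIGHT) for w in range(WIDTH))
blue = bytes(HEIGHT * WIDTH)
label, image = parse_cifar_record(bytes([3]) + red + green + blue)
print(label, image[5][7])
```

Because the red plane encodes the row and the green plane the column, image[5][7] comes out as [5, 7, 0], confirming that the transpose puts the channel axis last.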

4 TFRecords files

The tf.train.Example protocol buffer defines the format used to serialize the data:

message Example{
  Features features = 1;
};

message Features{
  map<string, Feature> feature = 1;    // maps each feature name to its value
};

message Feature{
  oneof kind{
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};
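To see how the three messages nest, here is a plain-Python mirror of the structure built with dataclasses. It is not the generated protobuf class tf.train.Example and does not serialize anything; the class names only echo the proto, and the check in Feature reflects that a oneof holds at most one of its fields.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Feature:
    """Mirror of the Feature message: `oneof kind` allows at most one list."""
    bytes_list: Optional[List[bytes]] = None
    float_list: Optional[List[float]] = None
    int64_list: Optional[List[int]] = None

    def __post_init__(self):
        kinds = [name for name in ("bytes_list", "float_list", "int64_list")
                 if getattr(self, name) is not None]
        if len(kinds) > 1:
            raise ValueError(f"oneof kind: only one may be set, got {kinds}")


@dataclass
class Features:
    """Mirror of Features: the map from feature name to Feature."""
    feature: Dict[str, Feature] = field(default_factory=dict)


@dataclass
class Example:
    """Mirror of Example: a single Features field."""
    features: Features


example = Example(features=Features(feature={
    "image": Feature(bytes_list=[b"\x00\x7f\xff"]),
    "label": Feature(int64_list=[6]),
}))
print(example.features.feature["label"].int64_list)
```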

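4.1 Steps to write a TFRecords file:

       1) Build the file writer: writer = tf.python_io.TFRecordWriter(path)

       2) Build a tf.train.Example object

Note: a Feature supports only the three list types bytes_list, float_list, and int64_list, so an image array must first be serialized to bytes with tostring() (tobytes() in newer NumPy).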

   example = tf.train.Example(features=tf.train.Features(feature={
                    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                }))

       3) Serialize the tf.train.Example object and write it to the file

   writer.write(example.SerializeToString())
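The reason for tostring() (tobytes() in newer NumPy), and for the later tf.decode_raw plus tf.reshape, can be shown with the standard array module standing in for a uint8 NumPy array: serializing keeps only the raw bytes, so the reader must know both the dtype and the shape to rebuild the image. The tiny 2 x 2 single-channel "image" is made up for illustration.

```python
from array import array

HEIGHT, WIDTH = 2, 2
pixels = array("B", [0, 127, 255, 64])  # "B" = unsigned 8-bit, like uint8

raw = pixels.tobytes()  # what goes into bytes_list: 4 bytes, no shape, no dtype
restored = array("B", raw)  # like tf.decode_raw(image, tf.uint8): a flat vector
# like tf.reshape(..., [HEIGHT, WIDTH]): the reader must supply the shape
image = [list(restored[r * WIDTH:(r + 1) * WIDTH]) for r in range(HEIGHT)]
print(len(raw), image)

wrong = array("H")  # decoding as 16-bit values misreads the same bytes
wrong.frombytes(raw)
print(len(wrong))
```

Decoding with the wrong dtype silently yields a different number of elements, which is why the out_type passed to tf.decode_raw must match the dtype used when writing.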

4.2 Steps to read a TFRecords file:

Parse a single Example: tf.parse_single_example(serialized, features, name=None, example_names=None)

Describe each feature to parse: tf.FixedLenFeature(shape, dtype)

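        1) Build the filename queue
        2) Read and decode
            Read: key, value = reader.read(file_queue), with reader = tf.TFRecordReader()
            Parse the Example:
            feature = tf.parse_single_example(value, features={
                "image": tf.FixedLenFeature([], tf.string),
                "label": tf.FixedLenFeature([], tf.int64)
            })
            image = feature["image"]
            label = feature["label"]
            Decode: tf.decode_raw(image, tf.uint8), then reshape to the image shape
        3) Build the batch queue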
