欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

tfrecord文件生成与读取

程序员文章站 2022-08-12 11:10:56
参考博客——tensorflow-TFRecord 文件详解1. 生成tfrecord文件代码#1.创建tfrecord对象tf_record=tf.python_io.TFRecordWriter(tf_record_name)tf.train.Int64List(value=list_data)tf.train.FloatList( )tf.train.BytesList()tf.train.Feature(int64_list=)tf.train.Feature(float_l...

参考博客——tensorflow-TFRecord 文件详解

1. 生成tfrecord文件

tfrecord文件生成与读取
代码

#1.创建tfrecord对象
tf_record=tf.python_io.TFRecordWriter(tf_record_name)

tf.train.Int64List(value=list_data)
tf.train.FloatList( )
tf.train.BytesList()

tf.train.Feature(int64_list=)
tf.train.Feature(float_list=tf.train.FloatList())
tf.train.Feature(bytes_list=tf.train.BytesList())

tf.train.Features(feature=dict_data)
ut = tf.train.Features(feature={"suibian": tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 2, 4])),"a":tf.train.Feature(float_list=tf.train.FloatList(value=[5., 7.]))})

example=tf.train.Example(features=tf.train.Features(...))

#2. 写入example对象序列化后的结果
tfrecord_writer.write(example.SerializeToString())

2. 读取tfrecord文件

从文件读取有 3 大步骤

  1. 生成读取器,不同类型的文件有对应的读取器

  2. 把文件名列表生成队列

  3. 用读取器的 read 方法读取队列中的文件
    tfrecord文件生成与读取
    tfrecord文件生成与读取

3 代码

3.1 dataset_to_tfrecord.py

tfrecord文件生成与读取

import os
import xml.etree.ElementTree as ET
import tensorflow as tf
from dataset_config import DIRECTORY_ANNOTATIONS,DIRECTORY_IMAGES,NUM_IMAGES_TFRECORD,labels_to_class
from utils.data_process_util import int64_feature,float_feature,bytes_feature
def _convert_to_example(img,img_shape,labels,trunacted,difficult,bndbox_size):
    '''将一张图片使用example,转换成protobuffer 格式
    :param img:
    :param img_shape:
    :param labels:
    :param trunacted:
    :param difficult:
    :param bndbox_size:
    :return:
    '''
    # 为了转换需求,bbox由单个obj的四个位置值,
    # 转变成四个位置的单独列表
    # 即:[[12,120,330,333],[50,60,100,200]]————>[[12,50],[120,60],[330,100],[333,200]]
    ymin=[]
    xmin=[]
    ymax=[]
    xmax=[]
    for b in bndbox_size:
        ymin.append(b[0])
        xmin.append(b[1])
        ymax.append(b[2])
        xmax.append(b[3])
    img_format = b'JPEG'
    print(type(labels))
    for i,label in enumerate(labels):
        labels[i]=labels_to_class[label]
    print('trunacted:',trunacted,type(trunacted),len(trunacted))

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height':int64_feature(img_shape[0]),
        'image/width':int64_feature(img_shape[1]),
        'image/channels':int64_feature(img_shape[2]),
        'image/shape':int64_feature(img_shape),
        'image/object/bbox/xmin':float_feature(xmin),
        'image/object/bbox/ymin':float_feature(ymin),
        'image/object/bbox/xmax':float_feature(xmax),
        'image/object/bbox/ymax':float_feature(ymax),
        'image/object/bbox/label_text':int64_feature(labels),
        # 'image/object/bbox/trunacted':bytes_feature(trunacted),
        # 'image/object/bbox/difficult':bytes_feature(difficult),
        'image/object/bbox/format':bytes_feature(img_format),
        'image/object/bbox/data':bytes_feature(img)# 读取的图像值
    }))
    print(img_format)
    return example

def _process_image(dataset_dir,img_name):
    '''
    读取图像和xml文件
    :param dataset_dir:
    :param img_name:
    :return:
    '''
    #1.读取图像
    #图像路径
    img_path = os.path.join(dataset_dir,DIRECTORY_IMAGES,img_name+'.jpg')
    img = tf.gfile.FastGFile(img_path,'rb').read()#tensorflow读取图像
    #2.读取xml
    #xml路径
    xml_path =os.path.join(dataset_dir,DIRECTORY_ANNOTATIONS,img_name+'.xml')
    tree = ET.parse(xml_path)
    root = tree.getroot()#获取根节点,'annotation'标签
    # 2.1获取图像尺寸信息
    size = root.find('size')
    img_shape=[
        int(size.find('height').text),
        int(size.find('width').text),
        int(size.find('depth').text)
    ]
    #2.2 获取bounding box 相关信息
    # bounding box可能有多个,用多个列表存储相关信息。
    labels = []
    trunacted=[]
    difficult = []
    bndbox_sizes=[]
    bboxes = root.findall('object')
    for obj in bboxes:
        label = obj.find('name').text
        if obj.find('trunacted'):
            trunacted.append(obj.find('trunacted').text)
        else:
            trunacted.append('0')
        if obj.find(''):
            difficult.append(obj.find('difficult').text)
        else:
            difficult.append(0)
        bndbox = obj.find('bndbox')
        bndbox_size=(
            float(bndbox.find('ymin').text)/img_shape[0],
            float(bndbox.find('xmin').text)/img_shape[1],
            float(bndbox.find('ymax').text)/img_shape[0],
            float(bndbox.find('xmax').text)/img_shape[1]

        )
        labels.append(label)
        trunacted.append(trunacted)
        difficult.append(difficult)
        bndbox_sizes.append(bndbox_size)
    return img,img_shape,labels,trunacted,difficult,bndbox_sizes


def _add_to_tfrecord(dataset_dir,img_name,tfrecord_writer):
    '''
    读取图片和xml文件,保存成一个Example
    :param dataset_dir:根目录
    :param img_name:图像名称
    :param tfrecord_writer:
    :return:
    '''
    #1.读取图片内容及相应的xml文件
    img, img_shape, labels, trunacted, difficult, bndbox_size=_process_image(dataset_dir,img_name)
    # return img,img_shape,labels,trunacted,difficult,bndbox_size
    #2.读取的内容封装成Example,
    example = _convert_to_example(img, img_shape, labels, trunacted, difficult, bndbox_size)

    #3.Example序列化结果写入指定tfrecord文件
    tfrecord_writer.write(example.SerializeToString())

def _get_output_tfrecord_name(output_dir,name,fdx):
    """

    :param output_dir:
    :param name:
    :param fdx:第几个tfrecord文件
    :return:
    """
    return os.path.join(output_dir,name,'%06d'%fdx+'.tfrecord')

def read_tfrecord():
    slim = tf.contrib.slim
    dataset = slim.dataset
    #第一个参数,文件路径
    file_pattern = os.path.join('tf_records\data','*.record')
    #第二个参数
    reader = tf.TFRecordReader


    # file_pattern = '%s-*  '  # 前面保存的tfrecord文件的文件名类似于“train-00001-of-00004.tfrecord”
    # file_pattern = os.path.join(dataset_dir, file_pattern % split_name)  # dataset_dir即前面保存的tfrecord文件的路径


    # 使用slim中的函数tf.FixedLenFeature将tfrecord的example反序列化成存储之前的格式,
    # 字符串格式的用''表示,整型格式的用0表示,其他确定的信息还原为原来的形式,如'jpeg','png'

    keys_to_features = {
        'image/object/bbox/data': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/object/bbox/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height': tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/width': tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/object/bbox/label_text': tf.FixedLenFeature((), tf.int64, default_value=0)}
    # 将反序列化的数据重组为更适合网络读入的格式
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(
            image_key='image/object/bbox/data',
            format_key='image/object/bbox/format',
            channels=3),
        # 'image_name': tfexample_decoder.Tensor('image/filename'),
        'height': slim.tfexample_decoder.Tensor('image/height'),
        'width': slim.tfexample_decoder.Tensor('image/width'),
        # 'labels_class': tfexample_decoder.Image(
        #     image_key='image/segmentation/class/encoded',
        #     format_key='image/segmentation/class/format',
        #     channels=1)
            }
    # 解码器进行解码,定义一个解码器对象,保存到dataset中
    # 第三个参数decoder
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    # 返回由tfrecord信息所得到的数据集dataset,dataset对象定义了数据集的文件位置,解码方式等元信息
    dataset = dataset.Dataset(
        data_sources=file_pattern,  # tfrecord路径
        reader=tf.TFRecordReader,  # 读取tfrecord文件的方式
        decoder=decoder,  # 解码tfrecord文件的方式
        num_samples=1464,  # PASCAL-VOC2012数据集训练样本数
        items_to_descriptions={  # 样本集图像和标签描述
            'image': 'A color image of varying height and width.',
            'labels_class': ('A semantic segmentation label whose size matches image.'
                             'Its values range from 0 (background) to num_classes.')},
        num_classes = 3,  # 数据集包含类别数(20个前景类别和1个背景类别)
        multi_label = True)  # 多标签(具体我也不太清楚)

    dataset_data_provider = slim.dataset_data_provider
    prefetch_queue = slim.prefetch_queue

    # 创建一个DatasetDataProvider类的对象data_provider,根据dataset和其他的一些已知信息读取数据。
    data_provider = dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=1,
        num_epochs=None,
        shuffle=True)
    # 通过调用data_provider对象的get实例函数能够根据data_provider中给出的信息解读tfrecord文件,生成图像和标签和图像文件名
    image, height, width = data_provider.get(['image', 'height', 'width'])
    # image_name, = data_provider.get(['image_name'])
    # label = data_provider.get(['label'])
    # 图像预处理过程,这里具体的处理过程与本文主题无关,因此省略具体的处理过程
    return image, height, width

def run(dataset_dir,output_dir,name='data'):
    """
    运行转换代码逻辑。
    存入多个tfrecord文件,每个文件固定N个样本
    :param dataset_dir:数据集目录,包含annotations,jpeg文件夹
    :param output_dir:tfrecords存储目录
    :param name:数据集名字,指定名字以及train or test
    :return:
    """
    # 1. 判断数据集目录是否存在,创建一个目录
    if tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)
    # 输出路径需要已存在
    # if tf.gfile.Exists(output_dir):
    #     tf.gfile.MakeDirs(output_dir)
    # 2. 读取某个文件夹下的所有文件名字列表
    dir_path = os.path.join(dataset_dir,DIRECTORY_ANNOTATIONS)
    files_path = sorted(os.listdir(dir_path))
    print(files_path)
    # 3. 循环名字列表,
    # 每200(NUM_IMAGES_TFRECORD)个图片及xml文件存储到一个tfrecord文件中
    num = len(files_path)
    i = 0
    fdx = 0
    while i < num:
        tf_record_name = _get_output_tfrecord_name(output_dir,name,fdx)
        with tf.python_io.TFRecordWriter(tf_record_name) as tf_record_writer:
            j = 0
            while i<num and j < NUM_IMAGES_TFRECORD:
                xml_path = files_path[i]
                img_name = xml_path.split('.')[0]
                #每个图像构建一个Example,保存到tf_record_name中
                _add_to_tfrecord(dataset_dir,img_name,tf_record_writer)

                j += 1
                i += 1

        fdx += 1
        print('fdx',fdx)
    print('数据集%s转换成功'%(dataset_dir))



3.2 tfrecord文件读取

tfrecord文件生成与读取

def read_tfrecord():
    slim = tf.contrib.slim
    dataset = slim.dataset
    #第一个参数,文件路径
    file_pattern = os.path.join('tf_records\data','*.tfrecord')
    #第二个参数
    reader = tf.TFRecordReader


    # file_pattern = '%s-*  '  # 前面保存的tfrecord文件的文件名类似于“train-00001-of-00004.tfrecord”
    # file_pattern = os.path.join(dataset_dir, file_pattern % split_name)  # dataset_dir即前面保存的tfrecord文件的路径


    # 使用slim中的函数tf.FixedLenFeature将tfrecord的example反序列化成存储之前的格式,
    # 字符串格式的用''表示,整型格式的用0表示,其他确定的信息还原为原来的形式,如'jpeg','png'

    keys_to_features = {
        'image/object/bbox/data': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/object/bbox/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height': tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/width': tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/object/bbox/label_text': tf.FixedLenFeature((), tf.int64, default_value=0)}
    # 将反序列化的数据重组为更适合网络读入的格式
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(
            image_key='image/object/bbox/data',
            format_key='image/object/bbox/format',
            channels=3),
        # 'image_name': tfexample_decoder.Tensor('image/filename'),
        'height': slim.tfexample_decoder.Tensor('image/height'),
        'width': slim.tfexample_decoder.Tensor('image/width'),
        # 'labels_class': tfexample_decoder.Image(
        #     image_key='image/segmentation/class/encoded',
        #     format_key='image/segmentation/class/format',
        #     channels=1)
            }
    # 解码器进行解码,定义一个解码器对象,保存到dataset中
    # 第三个参数decoder
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    # 返回由tfrecord信息所得到的数据集dataset,dataset对象定义了数据集的文件位置,解码方式等元信息
    dataset = dataset.Dataset(
        data_sources=file_pattern,  # tfrecord路径
        reader=tf.TFRecordReader,  # 读取tfrecord文件的方式
        decoder=decoder,  # 解码tfrecord文件的方式
        num_samples=1464,  # PASCAL-VOC2012数据集训练样本数
        items_to_descriptions={  # 样本集图像和标签描述
            'image': 'A color image of varying height and width.',
            'labels_class': ('A semantic segmentation label whose size matches image.'
                             'Its values range from 0 (background) to num_classes.')},
        num_classes = 3,  # 数据集包含类别数(20个前景类别和1个背景类别)
        multi_label = True)  # 多标签(具体我也不太清楚)

    dataset_data_provider = slim.dataset_data_provider
    prefetch_queue = slim.prefetch_queue

    # 创建一个DatasetDataProvider类的对象data_provider,根据dataset和其他的一些已知信息读取数据。
    data_provider = dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=1,
        num_epochs=None,
        shuffle=True)
    # 通过调用data_provider对象的get实例函数能够根据data_provider中给出的信息解读tfrecord文件,生成图像和标签和图像文件名
    image, height, width = data_provider.get(['image', 'height', 'width'])
    # image_name, = data_provider.get(['image_name'])
    # label = data_provider.get(['label'])
    # 图像预处理过程,这里具体的处理过程与本文主题无关,因此省略具体的处理过程
    return image, height, width

本文地址:https://blog.csdn.net/Blankit1/article/details/107167425