tfrecord文件生成与读取
程序员文章站
2022-08-12 11:10:56
参考博客——tensorflow-TFRecord 文件详解1. 生成tfrecord文件代码#1.创建tfrecord对象tf_record=tf.python_io.TFRecordWriter(tf_record_name)tf.train.Int64List(value=list_data)tf.train.FloatList( )tf.train.BytesList()tf.train.Feature(int64_list=)tf.train.Feature(float_l...
参考博客——tensorflow-TFRecord 文件详解
1. 生成tfrecord文件
代码
#1.创建tfrecord对象
tf_record=tf.python_io.TFRecordWriter(tf_record_name)
tf.train.Int64List(value=list_data)
tf.train.FloatList( )
tf.train.BytesList()
tf.train.Feature(int64_list=)
tf.train.Feature(float_list=tf.train.FloatList())
tf.train.Feature(bytes_list=tf.train.BytesList())
tf.train.Features(feature=dict_data)
ut = tf.train.Features(feature={"suibian": tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 2, 4])),"a":tf.train.Feature(float_list=tf.train.FloatList(value=[5., 7.]))})
example=tf.train.Example(features=tf.train.Features(...))
#2. 写入example对象序列化后的结果
tfrecord_writer.write(example.SerializeToString())
2. 读取tfrecord文件
从文件读取有 3 大步骤
-
生成读取器,不同类型的文件有对应的读取器
-
把文件名列表生成队列
-
用读取器的 read 方法读取队列中的文件
3 代码
3.1 dataset_to_tfrecord.py
import os
import xml.etree.ElementTree as ET
import tensorflow as tf
from dataset_config import DIRECTORY_ANNOTATIONS,DIRECTORY_IMAGES,NUM_IMAGES_TFRECORD,labels_to_class
from utils.data_process_util import int64_feature,float_feature,bytes_feature
def _convert_to_example(img,img_shape,labels,trunacted,difficult,bndbox_size):
'''将一张图片使用example,转换成protobuffer 格式
:param img:
:param img_shape:
:param labels:
:param trunacted:
:param difficult:
:param bndbox_size:
:return:
'''
# 为了转换需求,bbox由单个obj的四个位置值,
# 转变成四个位置的单独列表
# 即:[[12,120,330,333],[50,60,100,200]]————>[[12,50],[120,60],[330,100],[333,200]]
ymin=[]
xmin=[]
ymax=[]
xmax=[]
for b in bndbox_size:
ymin.append(b[0])
xmin.append(b[1])
ymax.append(b[2])
xmax.append(b[3])
img_format = b'JPEG'
print(type(labels))
for i,label in enumerate(labels):
labels[i]=labels_to_class[label]
print('trunacted:',trunacted,type(trunacted),len(trunacted))
example = tf.train.Example(features=tf.train.Features(feature={
'image/height':int64_feature(img_shape[0]),
'image/width':int64_feature(img_shape[1]),
'image/channels':int64_feature(img_shape[2]),
'image/shape':int64_feature(img_shape),
'image/object/bbox/xmin':float_feature(xmin),
'image/object/bbox/ymin':float_feature(ymin),
'image/object/bbox/xmax':float_feature(xmax),
'image/object/bbox/ymax':float_feature(ymax),
'image/object/bbox/label_text':int64_feature(labels),
# 'image/object/bbox/trunacted':bytes_feature(trunacted),
# 'image/object/bbox/difficult':bytes_feature(difficult),
'image/object/bbox/format':bytes_feature(img_format),
'image/object/bbox/data':bytes_feature(img)# 读取的图像值
}))
print(img_format)
return example
def _process_image(dataset_dir,img_name):
'''
读取图像和xml文件
:param dataset_dir:
:param img_name:
:return:
'''
#1.读取图像
#图像路径
img_path = os.path.join(dataset_dir,DIRECTORY_IMAGES,img_name+'.jpg')
img = tf.gfile.FastGFile(img_path,'rb').read()#tensorflow读取图像
#2.读取xml
#xml路径
xml_path =os.path.join(dataset_dir,DIRECTORY_ANNOTATIONS,img_name+'.xml')
tree = ET.parse(xml_path)
root = tree.getroot()#获取根节点,'annotation'标签
# 2.1获取图像尺寸信息
size = root.find('size')
img_shape=[
int(size.find('height').text),
int(size.find('width').text),
int(size.find('depth').text)
]
#2.2 获取bounding box 相关信息
# bounding box可能有多个,用多个列表存储相关信息。
labels = []
trunacted=[]
difficult = []
bndbox_sizes=[]
bboxes = root.findall('object')
for obj in bboxes:
label = obj.find('name').text
if obj.find('trunacted'):
trunacted.append(obj.find('trunacted').text)
else:
trunacted.append('0')
if obj.find(''):
difficult.append(obj.find('difficult').text)
else:
difficult.append(0)
bndbox = obj.find('bndbox')
bndbox_size=(
float(bndbox.find('ymin').text)/img_shape[0],
float(bndbox.find('xmin').text)/img_shape[1],
float(bndbox.find('ymax').text)/img_shape[0],
float(bndbox.find('xmax').text)/img_shape[1]
)
labels.append(label)
trunacted.append(trunacted)
difficult.append(difficult)
bndbox_sizes.append(bndbox_size)
return img,img_shape,labels,trunacted,difficult,bndbox_sizes
def _add_to_tfrecord(dataset_dir,img_name,tfrecord_writer):
'''
读取图片和xml文件,保存成一个Example
:param dataset_dir:根目录
:param img_name:图像名称
:param tfrecord_writer:
:return:
'''
#1.读取图片内容及相应的xml文件
img, img_shape, labels, trunacted, difficult, bndbox_size=_process_image(dataset_dir,img_name)
# return img,img_shape,labels,trunacted,difficult,bndbox_size
#2.读取的内容封装成Example,
example = _convert_to_example(img, img_shape, labels, trunacted, difficult, bndbox_size)
#3.Example序列化结果写入指定tfrecord文件
tfrecord_writer.write(example.SerializeToString())
def _get_output_tfrecord_name(output_dir,name,fdx):
"""
:param output_dir:
:param name:
:param fdx:第几个tfrecord文件
:return:
"""
return os.path.join(output_dir,name,'%06d'%fdx+'.tfrecord')
def read_tfrecord():
slim = tf.contrib.slim
dataset = slim.dataset
#第一个参数,文件路径
file_pattern = os.path.join('tf_records\data','*.record')
#第二个参数
reader = tf.TFRecordReader
# file_pattern = '%s-* ' # 前面保存的tfrecord文件的文件名类似于“train-00001-of-00004.tfrecord”
# file_pattern = os.path.join(dataset_dir, file_pattern % split_name) # dataset_dir即前面保存的tfrecord文件的路径
# 使用slim中的函数tf.FixedLenFeature将tfrecord的example反序列化成存储之前的格式,
# 字符串格式的用''表示,整型格式的用0表示,其他确定的信息还原为原来的形式,如'jpeg','png'
keys_to_features = {
'image/object/bbox/data': tf.FixedLenFeature((), tf.string, default_value=''),
'image/object/bbox/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/height': tf.FixedLenFeature((), tf.int64, default_value=0),
'image/width': tf.FixedLenFeature((), tf.int64, default_value=0),
'image/object/bbox/label_text': tf.FixedLenFeature((), tf.int64, default_value=0)}
# 将反序列化的数据重组为更适合网络读入的格式
items_to_handlers = {
'image': slim.tfexample_decoder.Image(
image_key='image/object/bbox/data',
format_key='image/object/bbox/format',
channels=3),
# 'image_name': tfexample_decoder.Tensor('image/filename'),
'height': slim.tfexample_decoder.Tensor('image/height'),
'width': slim.tfexample_decoder.Tensor('image/width'),
# 'labels_class': tfexample_decoder.Image(
# image_key='image/segmentation/class/encoded',
# format_key='image/segmentation/class/format',
# channels=1)
}
# 解码器进行解码,定义一个解码器对象,保存到dataset中
# 第三个参数decoder
decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
# 返回由tfrecord信息所得到的数据集dataset,dataset对象定义了数据集的文件位置,解码方式等元信息
dataset = dataset.Dataset(
data_sources=file_pattern, # tfrecord路径
reader=tf.TFRecordReader, # 读取tfrecord文件的方式
decoder=decoder, # 解码tfrecord文件的方式
num_samples=1464, # PASCAL-VOC2012数据集训练样本数
items_to_descriptions={ # 样本集图像和标签描述
'image': 'A color image of varying height and width.',
'labels_class': ('A semantic segmentation label whose size matches image.'
'Its values range from 0 (background) to num_classes.')},
num_classes = 3, # 数据集包含类别数(20个前景类别和1个背景类别)
multi_label = True) # 多标签(具体我也不太清楚)
dataset_data_provider = slim.dataset_data_provider
prefetch_queue = slim.prefetch_queue
# 创建一个DatasetDataProvider类的对象data_provider,根据dataset和其他的一些已知信息读取数据。
data_provider = dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=1,
num_epochs=None,
shuffle=True)
# 通过调用data_provider对象的get实例函数能够根据data_provider中给出的信息解读tfrecord文件,生成图像和标签和图像文件名
image, height, width = data_provider.get(['image', 'height', 'width'])
# image_name, = data_provider.get(['image_name'])
# label = data_provider.get(['label'])
# 图像预处理过程,这里具体的处理过程与本文主题无关,因此省略具体的处理过程
return image, height, width
def run(dataset_dir,output_dir,name='data'):
"""
运行转换代码逻辑。
存入多个tfrecord文件,每个文件固定N个样本
:param dataset_dir:数据集目录,包含annotations,jpeg文件夹
:param output_dir:tfrecords存储目录
:param name:数据集名字,指定名字以及train or test
:return:
"""
# 1. 判断数据集目录是否存在,创建一个目录
if tf.gfile.Exists(dataset_dir):
tf.gfile.MakeDirs(dataset_dir)
# 输出路径需要已存在
# if tf.gfile.Exists(output_dir):
# tf.gfile.MakeDirs(output_dir)
# 2. 读取某个文件夹下的所有文件名字列表
dir_path = os.path.join(dataset_dir,DIRECTORY_ANNOTATIONS)
files_path = sorted(os.listdir(dir_path))
print(files_path)
# 3. 循环名字列表,
# 每200(NUM_IMAGES_TFRECORD)个图片及xml文件存储到一个tfrecord文件中
num = len(files_path)
i = 0
fdx = 0
while i < num:
tf_record_name = _get_output_tfrecord_name(output_dir,name,fdx)
with tf.python_io.TFRecordWriter(tf_record_name) as tf_record_writer:
j = 0
while i<num and j < NUM_IMAGES_TFRECORD:
xml_path = files_path[i]
img_name = xml_path.split('.')[0]
#每个图像构建一个Example,保存到tf_record_name中
_add_to_tfrecord(dataset_dir,img_name,tf_record_writer)
j += 1
i += 1
fdx += 1
print('fdx',fdx)
print('数据集%s转换成功'%(dataset_dir))
3.2 tfrecord文件读取
def read_tfrecord():
slim = tf.contrib.slim
dataset = slim.dataset
#第一个参数,文件路径
file_pattern = os.path.join('tf_records\data','*.tfrecord')
#第二个参数
reader = tf.TFRecordReader
# file_pattern = '%s-* ' # 前面保存的tfrecord文件的文件名类似于“train-00001-of-00004.tfrecord”
# file_pattern = os.path.join(dataset_dir, file_pattern % split_name) # dataset_dir即前面保存的tfrecord文件的路径
# 使用slim中的函数tf.FixedLenFeature将tfrecord的example反序列化成存储之前的格式,
# 字符串格式的用''表示,整型格式的用0表示,其他确定的信息还原为原来的形式,如'jpeg','png'
keys_to_features = {
'image/object/bbox/data': tf.FixedLenFeature((), tf.string, default_value=''),
'image/object/bbox/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/height': tf.FixedLenFeature((), tf.int64, default_value=0),
'image/width': tf.FixedLenFeature((), tf.int64, default_value=0),
'image/object/bbox/label_text': tf.FixedLenFeature((), tf.int64, default_value=0)}
# 将反序列化的数据重组为更适合网络读入的格式
items_to_handlers = {
'image': slim.tfexample_decoder.Image(
image_key='image/object/bbox/data',
format_key='image/object/bbox/format',
channels=3),
# 'image_name': tfexample_decoder.Tensor('image/filename'),
'height': slim.tfexample_decoder.Tensor('image/height'),
'width': slim.tfexample_decoder.Tensor('image/width'),
# 'labels_class': tfexample_decoder.Image(
# image_key='image/segmentation/class/encoded',
# format_key='image/segmentation/class/format',
# channels=1)
}
# 解码器进行解码,定义一个解码器对象,保存到dataset中
# 第三个参数decoder
decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
# 返回由tfrecord信息所得到的数据集dataset,dataset对象定义了数据集的文件位置,解码方式等元信息
dataset = dataset.Dataset(
data_sources=file_pattern, # tfrecord路径
reader=tf.TFRecordReader, # 读取tfrecord文件的方式
decoder=decoder, # 解码tfrecord文件的方式
num_samples=1464, # PASCAL-VOC2012数据集训练样本数
items_to_descriptions={ # 样本集图像和标签描述
'image': 'A color image of varying height and width.',
'labels_class': ('A semantic segmentation label whose size matches image.'
'Its values range from 0 (background) to num_classes.')},
num_classes = 3, # 数据集包含类别数(20个前景类别和1个背景类别)
multi_label = True) # 多标签(具体我也不太清楚)
dataset_data_provider = slim.dataset_data_provider
prefetch_queue = slim.prefetch_queue
# 创建一个DatasetDataProvider类的对象data_provider,根据dataset和其他的一些已知信息读取数据。
data_provider = dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=1,
num_epochs=None,
shuffle=True)
# 通过调用data_provider对象的get实例函数能够根据data_provider中给出的信息解读tfrecord文件,生成图像和标签和图像文件名
image, height, width = data_provider.get(['image', 'height', 'width'])
# image_name, = data_provider.get(['image_name'])
# label = data_provider.get(['label'])
# 图像预处理过程,这里具体的处理过程与本文主题无关,因此省略具体的处理过程
return image, height, width
本文地址:https://blog.csdn.net/Blankit1/article/details/107167425