欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

tf.feature_column解析tfrecord

程序员文章站 2024-01-19 13:25:52
...

对于获取的一个tfrecord eg. a.tfrecord,若未知对应的key,则可参考tf.train.Example.FromString解析tfrecord


# tfrecord结构
# message Example
# {
#     Features features = 1;
# };
#
# message Features
# {
#     map< string, Feature > featrue = 1;
# };
#
# message Feature
#
#     oneof kind{
#     BytesList bytes_list = 1;
#     FloatList float_list = 2;
#     Int64List int64_list = 3;
# }
# };

上一篇:tf 解析tfrecord  介绍tf.io.FixedLenFeature解析tfrecord

ps:在这篇中,介绍一种利用tf.feature_column来生成映射字典从而来解析tfrecord,但默认tfrecord使用tf.float32来存储,暂未知如何利用tf.feature_column来生成关于tf.string的映射字典,望大佬们不吝告知!!!!!

同样是cifar10的tfrecord,下载链接:https://download.csdn.net/download/u014426939/13619698

目前,新版本的tensorflow关于输入输出相关的函数都整合到tf.io下,因此,若想用FixedLenFeature,VarLenFeature,都需要写作tf.io.FixedLenFeature,tf.io.VarLenFeature,

解析函数也不例外:tf.io.parse_example, tf.io.parse_single_example

tf.io.parse_example:解析多个examples(向量型)

tf.io.parse_single_example:解析单个example

直接上代码感受:

import tensorflow as tf
import glob
tf.enable_eager_execution()#Eager模式


path='cifar/cifar10*.tfrecord*'

files=glob.glob(path)

label = tf.feature_column.numeric_column("label", shape=(), dtype=tf.dtypes.int64)   

feature_columns = [label]

features = tf.feature_column.make_parse_example_spec(feature_columns)  #生成featuredict
#out: features:{'label': FixedLenFeature(shape=(), dtype=tf.int64, default_value=None)}
data = tf.data.TFRecordDataset(files)  #读取tfrecord
#分别用tf.io.parse_single_example  和 tf.io.parse_example解析数据

data1 = data.map(lambda x : tf.io.parse_single_example(x, features = features ))
#data1.__iter__().next()
#out:{'label': <tf.Tensor: id=82, shape=(), dtype=int64, numpy=7>}

####################################################################
#利用parse_example 解析向量型的examples
#应用场景 同时获取多个tfrecord文件中的example,批量解析 
data_batch=data.batch(2)
data2 = data_batch.map(lambda x : tf.io.parse_example(x, features = features))
# data2.__iter__().next()
#out:{'label': <tf.Tensor: id=293, shape=(2,), dtype=int64, numpy=array([7, 0], #dtype=int64)>}