TensorFlow2.0 (4) dataset 使用
程序员文章站
2022-06-06 08:40:35
...
摘要:在第一章我们学会了用 TensorFlow2 来构建模型,第二章又学习了超参数搜索,第三章学会了 TensorFlow 基础 api 的实现。但想要训练出更好的模型,还有一步极其关键的步骤,数据的输入与处理,在实际工作项目中,数据的处理与输入甚至可能占用 60% 的时间。
一、 Dataset 基础 API使用
1.1 tf.data.Dataset.from_tensor_slices
我们先使用第一种方法,从内存构建数据
# 日常 import
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras
调用 from_tenso_slices
# 初始化一个 1*10 的一维向量的 dataset
dataset = tf.data.Dataset.from_tensor_slices(np.arange(10))
print(dataset)
# 那么对 dataset 我们可以遍历
for item in dataset:
print(item)
# 重复读取:repeat epoch 遍历一次为一次 epoch
# 一次获取的数据量:get batch
# 遍历三次。每次 7 个
dataset = dataset.repeat(3).batch(7)
# interleave: 对现有的每个元素进行处理,形成一个新的数据集
# case: 做一个变化,把文件名对应的文件内容读取出来,再讲文件内容合并起来,变成一个总的数据集
dataset2 = dataset.interleave(
lambda v: tf.data.Dataset.from_tensor_slices(v),# map_fn :数据变换形式
cycle_length = 5,# cycle_length : 同时处理的数据个数
block_length = 5,# block_length : 每次取多少个元素出来
)
# 然后我们再遍历这个 dataset2
for item in dataset2:
# 显示每个被 dataset 从 dataset2 获取到的数据
print(item)
# 通过这个结果我们会发现,它会在每个数组前取 5 个数,在最后不足 5 个数的时候,会从之前只取 5 个数时遗漏的尾部数据那按顺序提取数据,这样就达到了一个均匀混合的效果
这里我们输入进去的是 np 的向量,我们尝试一下别的数据类型
除了支持 numpy 的数据,还支持 python 原有的元组字典等,我们先尝试一下元组
输入元组 (x,y)
# 除了支持 numpy 的数据,还支持 python 原有的元组字典等
# 设定一个二维矩阵,和二维矩阵对应的对象
x = np.array([1,2],[3,4],[5,6]])
y = np.array(['cat','dog','fox'])
# 这一步就是将两个数组以元组的方式输入
dataset3 = tf.data.Dataset.from_tensor_slices((x,y))
print(dataset3)
for item_x,item_y in dataset3:
# 这里直接输出 item 的话,是tensor类型,想简单一点的话,像之前所讲,加个 numpy() 就好
print(item_x.numpy(),item_y.numpy())
我们再试试字典形式
dataset4 = tf.data.Dataset.from_tensor_slices({"fesature":x,
"label":y})
# 直接输出 item 的字典数据
for item in dataset4:
print(item)
# 简洁的输出方法
for item in dataset4:
print(item["feature"].numpy(),item["label"].numpy())
1.2 repeat ,batch ,interleave,map,shuffle,list_files,
二、Dataset 读取 CSV 文件
2.1 tf.data.TextLineDataset , tf.io.decode_csv
三、 Dataset 读取 tfrecord 文件
3.1 tf.train.FloatList ,tf.train.Int64List ,tf.train.BytesList
3.2 tf.train.Feature ,tf.train.Features ,tf.train.Example
3.3 example.SerializeToString
3.4 tf.io.ParseSingleExample
3.5 tf.io.VarLenFeature ,tf.io.FixedLenFeature
3.6 tf.data.TFRecordDataset ,tf.io.TFRecordOptions
推荐阅读
-
C4D目标标签怎么使用?
-
webpack4的迁移的使用方法
-
php中有关字符串的4个函数substr、strrchr、strstr、ereg介绍和使用例子
-
DataReader、DataSet、DataAdapter和DataView使用介绍
-
MP4Joiner怎么用?使用MP4Joiner快速合并多个mp4视频文件的方法介绍
-
ASP.NET Log4Net日志的配置及使用,文件写入
-
c4d怎么阵列对象? c4d阵列的使用方法
-
使用vs2010编译log4cxx图文教程
-
HTML5自定义属性前缀data-及dataset的使用方法(html5 新特性)
-
c4d怎么使用刚体标签制作小球体滑落的动画?