欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

TensorFlow2.0 (4) dataset 使用

程序员文章站 2022-06-06 08:40:35
...

摘要:在第一章我们学会了用 TensorFlow2 来构建模型,第二章又学习了超参数搜索,第三章学会了 TensorFlow 基础 api 的实现。但想要训练出更好的模型,还有一步极其关键的步骤,数据的输入与处理,在实际工作项目中,数据的处理与输入甚至可能占用 60% 的时间。

一、 Dataset 基础 API使用

          1.1 tf.data.Dataset.from_tensor_slices

                我们先使用第一种方法,从内存构建数据

# 日常 import
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn 
import pandas as pd
import os
import sys
import time
import tensorflow as tf

from tensorflow import keras

               调用 from_tenso_slices

# 初始化一个 1*10 的一维向量的 dataset
dataset = tf.data.Dataset.from_tensor_slices(np.arange(10))
print(dataset)

# 那么对 dataset 我们可以遍历
for item in dataset:
    print(item)

# 重复读取:repeat epoch 遍历一次为一次 epoch
# 一次获取的数据量:get batch
# 遍历三次。每次 7 个
dataset = dataset.repeat(3).batch(7)

# interleave: 对现有的每个元素进行处理,形成一个新的数据集
# case: 做一个变化,把文件名对应的文件内容读取出来,再讲文件内容合并起来,变成一个总的数据集

dataset2 = dataset.interleave(
    lambda v: tf.data.Dataset.from_tensor_slices(v),# map_fn :数据变换形式
    cycle_length = 5,# cycle_length : 同时处理的数据个数
    block_length = 5,# block_length : 每次取多少个元素出来

)

# 然后我们再遍历这个 dataset2
for item in dataset2:
    # 显示每个被 dataset 从 dataset2 获取到的数据
    print(item) 
# 通过这个结果我们会发现,它会在每个数组前取 5 个数,在最后不足 5 个数的时候,会从之前只取 5 个数时遗漏的尾部数据那按顺序提取数据,这样就达到了一个均匀混合的效果

              这里我们输入进去的是 np 的向量,我们尝试一下别的数据类型

              除了支持 numpy 的数据,还支持 python 原有的元组字典等,我们先尝试一下元组

              输入元组 (x,y)

# 除了支持 numpy 的数据,还支持 python 原有的元组字典等
# 设定一个二维矩阵,和二维矩阵对应的对象
x = np.array([1,2],[3,4],[5,6]])
y = np.array(['cat','dog','fox'])
# 这一步就是将两个数组以元组的方式输入
dataset3 = tf.data.Dataset.from_tensor_slices((x,y))
print(dataset3)

for item_x,item_y in dataset3:
    # 这里直接输出 item 的话,是tensor类型,想简单一点的话,像之前所讲,加个 numpy() 就好
    print(item_x.numpy(),item_y.numpy())

             我们再试试字典形式

dataset4 = tf.data.Dataset.from_tensor_slices({"fesature":x,
                                                "label":y})
# 直接输出 item 的字典数据
for item in dataset4:
    print(item)
# 简洁的输出方法
for item in dataset4:
    print(item["feature"].numpy(),item["label"].numpy())

 

 

          1.2 repeat ,batch ,interleave,map,shuffle,list_files,

 

二、Dataset 读取 CSV 文件

         2.1 tf.data.TextLineDataset , tf.io.decode_csv

 

 

 

三、 Dataset 读取 tfrecord 文件

         3.1 tf.train.FloatList ,tf.train.Int64List ,tf.train.BytesList

         3.2 tf.train.Feature ,tf.train.Features ,tf.train.Example

         3.3 example.SerializeToString

         3.4 tf.io.ParseSingleExample

         3.5 tf.io.VarLenFeature ,tf.io.FixedLenFeature

         3.6 tf.data.TFRecordDataset ,tf.io.TFRecordOptions

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

相关标签: TensorFlow 2.0