欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

tf.data.TextLineDataset 解析csv

程序员文章站 2024-01-19 13:30:16
...

TextLineDataset可以将文本类的数据映射到tesorflow的Dataset

TextLineDataset的读取机制:读取文本中数据,一行代表一组数据,一般再按照filter、map、shuffle、batch、repeat、prefetch的顺序获得可用数据。

在csv中,因存在头部行,使用filter对数据进行筛选,去掉不符合数据结构的行

上代码:

from six.moves.urllib.request import urlopen
import os

import numpy as np
import tensorflow as tf

IRIS_TRAINING = "iris_training.csv"
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"

IRIS_TEST = "iris_test.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

if not os.path.exists(IRIS_TRAINING):
   raw = urlopen(IRIS_TRAINING_URL).read()
   with open(IRIS_TRAINING, "wb") as f:
       f.write(raw)

if not os.path.exists(IRIS_TEST):
   raw = urlopen(IRIS_TEST_URL).read()
   with open(IRIS_TEST, "wb") as f:
       f.write(raw)

fun = lambda x1: not tf.strings.regex_full_match(x1, '.*[a-z|A-Z].*')  # 判断是否存在字母

def funStringSplit(x):
   split_strings = tf.strings.to_number(tf.strings.split(x, ','))  # 分割字符串
   features, target = tf.split(split_strings, [4, 1], axis=0)
   return features, target

trainSet = tf.data.TextLineDataset([IRIS_TRAINING])
testSet = tf.data.TextLineDataset([IRIS_TEST])

trainSet = trainSet.filter(fun)
testSet = testSet.filter(fun)

trainSet = trainSet.map(funStringSplit,num_parallel_calls=4)
testSet = testSet.map(funStringSplit,num_parallel_calls=4)