Creating datasets with tf.data (TensorFlow 2)

程序员文章站 (Programmer Article Station) 2024-01-19 13:51:16

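The tf.data API lets you build a data input pipeline from simple, reusable pieces. For example, an image pipeline can read files from a distributed file system, add random noise to each image, and combine randomly selected images into a batch for training.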
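The pipeline described above can be sketched as follows. This is only an illustration under stated assumptions: the images are synthetic all-zero tensors held in memory rather than files on a distributed file system, and the helper name add_noise and the noise level 0.1 are made up for the example.

```python
import tensorflow as tf

# Assumption: 8 synthetic 4x4 grayscale "images" stand in for real decoded files.
images = tf.zeros([8, 4, 4, 1], dtype=tf.float32)

def add_noise(image):
    # Hypothetical helper: add Gaussian noise to one image (stddev chosen arbitrarily).
    return image + tf.random.normal(tf.shape(image), stddev=0.1)

pipeline = (
    tf.data.Dataset.from_tensor_slices(images)  # one element per image
    .shuffle(buffer_size=8)                     # draw images in random order
    .map(add_noise)                             # perturb each image independently
    .batch(4)                                   # group 4 images per training batch
)

batch_shapes = [batch.shape.as_list() for batch in pipeline]
print(batch_shapes)  # [[4, 4, 4, 1], [4, 4, 4, 1]]
```

Each step returns a new Dataset, so the stages can be reordered or reused independently.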

The tf.data API introduces the tf.data.Dataset abstraction, which represents a sequence of elements. Each element consists of one or more components.

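There are two ways to create a dataset:

1. Build a Dataset from a data source held in memory or stored in files.
2. Derive a Dataset from one or more existing Dataset objects through a transformation.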
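A small sketch of the two approaches side by side. The variable names are illustrative only; map and filter are used here as two examples of transformations.

```python
import tensorflow as tf

# Way 1: build a Dataset from a source held in memory.
source = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])

# Way 2: derive new Datasets from an existing one through transformations.
squared = source.map(lambda x: x * x)
evens = squared.filter(lambda x: x % 2 == 0)

squared_values = [int(x) for x in squared.as_numpy_iterator()]
even_values = [int(x) for x in evens.as_numpy_iterator()]
print(squared_values)  # [1, 4, 9, 16]
print(even_values)     # [4, 16]
```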

Creating a basic Dataset

import numpy as np
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset
#<TensorSliceDataset shapes: (), types: tf.int32>

for elem in dataset:
  print(elem.numpy())
#8
#3
#0
#8
#2
#1

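A Dataset is a Python iterable, so you can also create an iterator over it explicitly with iter and fetch elements with next: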

it = iter(dataset)
print(next(it).numpy())
#8
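To make the distinction concrete, the sketch below shows that the Dataset itself can be traversed repeatedly, while an iterator obtained from it is consumed as it advances. The variable names are illustrative.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])

# Every traversal of the Dataset starts again from the first element.
first_pass = [int(x) for x in dataset.as_numpy_iterator()]
second_pass = [int(x) for x in dataset.as_numpy_iterator()]

# An explicit iterator keeps its position: after next(), only the rest remains.
it = iter(dataset)
head = int(next(it).numpy())
rest = [int(x.numpy()) for x in it]

print(first_pass == second_pass)  # True
print(head, rest)                 # 8 [3, 0, 8, 2, 1]
```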

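As shown above, dataset is a Dataset object. Its reduce method folds all elements into a single result; the following example computes their sum: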

print(dataset.reduce(0, lambda state, value: state + value).numpy())
#22
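The same reduce pattern works for other aggregations. As a sketch, the first call counts the elements and the second tracks the maximum; using tf.int32.min as the starting state is just one convenient choice for this int32 dataset.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])

# Count elements: ignore each value and add 1 to the running state.
count = dataset.reduce(0, lambda state, _: state + 1)

# Track the largest element seen so far, starting from the smallest int32.
largest = dataset.reduce(
    tf.constant(tf.int32.min, dtype=tf.int32),
    lambda state, value: tf.maximum(state, value))

print(count.numpy())    # 6
print(largest.numpy())  # 8
```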

Dataset structure

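Every element of a dataset has the same (possibly nested) structure.
Each component of that structure can be any type representable by tf.TypeSpec, including Tensor, SparseTensor, RaggedTensor, TensorArray, or Dataset. The Dataset.element_spec property lets you inspect the type of each component: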

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))

dataset1.element_spec

#TensorSpec(shape=(10,), dtype=tf.float32, name=None)

dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2.element_spec

#(TensorSpec(shape=(), dtype=tf.float32, name=None),
# TensorSpec(shape=(100,), dtype=tf.int32, name=None))

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3.element_spec

#(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
# (TensorSpec(shape=(), dtype=tf.float32, name=None),
#  TensorSpec(shape=(100,), dtype=tf.int32, name=None)))

# Dataset containing a sparse tensor.
dataset4 = tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))

dataset4.element_spec

#SparseTensorSpec(TensorShape([3, 4]), tf.int32)

# Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type

#tensorflow.python.framework.sparse_tensor.SparseTensor
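Components can also be named. As a sketch, passing a dict to from_tensor_slices produces a dict-shaped element_spec; the keys "image" and "label" and the shapes are made up for this example.

```python
import tensorflow as tf

# Assumption: toy features; 4 examples with a 28x28 "image" and an integer label.
features = {
    "image": tf.zeros([4, 28, 28], dtype=tf.float32),
    "label": tf.constant([0, 1, 2, 3], dtype=tf.int64),
}
dataset = tf.data.Dataset.from_tensor_slices(features)

# element_spec mirrors the dict structure, with the leading dimension sliced away.
spec = dataset.element_spec
print(spec)

for element in dataset.take(1):
    print(element["image"].shape, element["label"].numpy())
```

Named components tend to be easier to keep straight than positional tuples once a pipeline has more than two or three of them.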

Creating a Dataset from a Python generator

def count(stop):
  i = 0
  while i<stop:
    yield i
    i += 1

for n in count(5):
  print(n)
#0
#1
#2
#3
#4

ds_counter = tf.data.Dataset.from_generator(
    count, args=[25], output_types=tf.int32, output_shapes=())

for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())

#[0 1 2 3 4 5 6 7 8 9]
#[10 11 12 13 14 15 16 17 18 19]
#[20 21 22 23 24  0  1  2  3  4]
#[ 5  6  7  8  9 10 11 12 13 14]
#[15 16 17 18 19 20 21 22 23 24]
#[0 1 2 3 4 5 6 7 8 9]
#[10 11 12 13 14 15 16 17 18 19]
#[20 21 22 23 24  0  1  2  3  4]
#[ 5  6  7  8  9 10 11 12 13 14]
#[15 16 17 18 19 20 21 22 23 24]
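The example above declares the element type with output_types and output_shapes. As an alternative sketch, assuming TensorFlow 2.4 or later, the same information can be given as a single tf.TensorSpec through the output_signature argument. A smaller stop value is used here so the result is easy to check by hand.

```python
import tensorflow as tf

def count(stop):
    i = 0
    while i < stop:
        yield i
        i += 1

# output_signature describes each yielded element: a scalar int32.
ds_counter = tf.data.Dataset.from_generator(
    count,
    args=[5],
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int32),
)

# Two passes over 0..4 give 10 elements; batches of 4 leave a final partial batch.
batches = [b.tolist() for b in ds_counter.repeat(2).batch(4).as_numpy_iterator()]
print(batches)  # [[0, 1, 2, 3], [4, 0, 1, 2], [3, 4]]
```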

Here is another example, in which the generator yields an index together with an array of random length:

def gen_series():
  i = 0
  while True:
    size = np.random.randint(0, 10)
    yield i, np.random.normal(size=(size,))
    i += 1

for i, series in gen_series():
  print(i, ":", str(series))
  if i > 5:
    break
"""
0 : [ 0.9475 -0.6361  0.9765]
1 : [-0.555   1.3723  0.1027 -1.0957  0.141 ]
2 : [-0.5906 -1.2747  0.5064  1.104   0.1396 -0.1937 -0.3695 -0.5508]
3 : [ 0.2029  0.7422  1.3038  1.0698  1.7587 -0.7051]
4 : [ 0.4777  0.568  -0.7713 -0.0322 -1.0875]
5 : [ 0.2634 -0.3093  0.6087]
6 : [-0.1843  0.6568  0.2268  2.1317 -0.2758 -0.4531]
"""
ds_series = tf.data.Dataset.from_generator(
    gen_series, 
    output_types=(tf.int32, tf.float32), 
    output_shapes=((), (None,)))

ds_series

#<FlatMapDataset shapes: ((), (None,)), types: (tf.int32, tf.float32)>

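Note that the sequences above have different lengths, so they cannot be stacked into a batch directly. Use padded_batch, which pads each sequence to the length of the longest one in the batch: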

ds_series_batch = ds_series.shuffle(20).padded_batch(10)

ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())
"""
[10  2 13  4 19 12 25 14  8  0]

[[ 0.6505  0.9339 -0.6796  0.      0.      0.      0.      0.    ]
 [ 1.4865 -0.5334  0.      0.      0.      0.      0.      0.    ]
 [ 0.9648  1.0677 -1.2092 -0.4564  0.9524 -0.5516 -0.8149  1.1307]
 [-1.2593  0.8061  0.7738 -0.6441 -1.3384  1.2362  0.      0.    ]
 [ 0.6877  1.9626 -1.0171 -0.7908  0.      0.      0.      0.    ]
 [ 0.2111  1.687  -0.0555 -0.0242 -1.2556  1.1843 -0.509   1.5797]
 [-0.6607  0.      0.      0.      0.      0.      0.      0.    ]
 [-0.3114 -0.6608  0.      0.      0.      0.      0.      0.    ]
 [ 0.8224 -1.2478 -0.9483  0.6411 -0.9707  1.659  -0.642   0.    ]
 [ 1.7515  1.3955 -0.9958 -0.1844  0.5085 -0.1619  1.0888  1.7601]]
 """
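padded_batch also accepts an explicit target shape and padding value. The sketch below uses a small fixed list instead of random data so the result is deterministic; the element values, the target length 4, and the padding value -1 are arbitrary choices for illustration.

```python
import tensorflow as tf

# Assumption: toy (id, sequence) pairs with sequences of different lengths.
elements = [(0, [1, 2, 3]), (1, [4]), (2, [5, 6])]

ds = tf.data.Dataset.from_generator(
    lambda: iter(elements),
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.int32),
        tf.TensorSpec(shape=(None,), dtype=tf.int32),
    ),
)

# Leave the scalar id unpadded ([]); pad every sequence to length 4 with -1.
padded = ds.padded_batch(3, padded_shapes=([], [4]), padding_values=(0, -1))

ids, seqs = next(iter(padded))
print(ids.numpy())   # [0 1 2]
print(seqs.numpy())
# [[ 1  2  3 -1]
#  [ 4 -1 -1 -1]
#  [ 5  6 -1 -1]]
```

Choosing a padding value that cannot occur in the real data makes it easy to mask the padded positions later.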

Next, read image data through a generator:

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)

images, labels = next(img_gen.flow_from_directory(flowers))

#Found 3670 images belonging to 5 classes.

print(images.dtype, images.shape)
print(labels.dtype, labels.shape)

#float32 (32, 256, 256, 3)
#float32 (32, 5)

ds = tf.data.Dataset.from_generator(
    img_gen.flow_from_directory, args=[flowers], 
    output_types=(tf.float32, tf.float32), 
    output_shapes=([32,256,256,3], [32,5])
)

ds

#<FlatMapDataset shapes: ((32, 256, 256, 3), (32, 5)), types: (tf.float32, tf.float32)>
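Wrapping ImageDataGenerator is not the only option. One alternative, sketched here, lists the image files and decodes them inside the pipeline with map. This is a different technique from the generator approach above, not a drop-in replacement. To stay self-contained it writes three tiny PNG files to a temporary directory; the file names, image sizes, and the helper load_image are assumptions for illustration.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: three 8x8 single-color PNGs stand in for a real image folder.
tmp_dir = tempfile.mkdtemp()
for i in range(3):
    pixels = tf.fill([8, 8, 3], tf.cast(i * 50, tf.uint8))
    tf.io.write_file(os.path.join(tmp_dir, f"img_{i}.png"), tf.io.encode_png(pixels))

def load_image(path):
    # Hypothetical helper: read, decode, scale to [0, 1], and resize one image.
    image = tf.io.decode_png(tf.io.read_file(path), channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    return tf.image.resize(image, [4, 4])

ds = (
    tf.data.Dataset.list_files(os.path.join(tmp_dir, "*.png"), shuffle=False)
    .map(load_image)
    .batch(2)
)

shapes = [batch.shape.as_list() for batch in ds]
print(shapes)  # [[2, 4, 4, 3], [1, 4, 4, 3]]
```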

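Reading and consuming TFRecord data

Use tf.data.TFRecordDataset to stream the contents of one or more TFRecord files as part of an input pipeline: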

# Creates a dataset that reads all of the examples from one file.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
'''
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
7905280/7904079 [==============================] - 1s 0us/step
'''
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset

#<TFRecordDatasetV2 shapes: (), types: tf.string>
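Each element of a TFRecordDataset is a serialized string, so it usually has to be parsed before use. The sketch below does not use the fsns file, whose feature layout is not described here. Instead it writes a toy TFRecord file with a single int64 feature named "value" (a made-up name) and reads it back.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: a toy file in a temporary directory, not the fsns dataset.
path = os.path.join(tempfile.mkdtemp(), "toy.tfrec")

# Write three tf.train.Example records, each holding one int64 feature.
with tf.io.TFRecordWriter(path) as writer:
    for value in [3, 1, 4]:
        example = tf.train.Example(features=tf.train.Features(feature={
            "value": tf.train.Feature(int64_list=tf.train.Int64List(value=[value])),
        }))
        writer.write(example.SerializeToString())

# The feature spec tells the parser what each serialized record contains.
feature_spec = {"value": tf.io.FixedLenFeature([], tf.int64)}

def parse(serialized):
    return tf.io.parse_single_example(serialized, feature_spec)

records = tf.data.TFRecordDataset(filenames=[path]).map(parse)
values = [int(r["value"]) for r in records.as_numpy_iterator()]
print(values)  # [3, 1, 4]
```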