数据集创建tf.data(tensorflow2)
程序员文章站
2024-01-19 13:51:16
...
tf.data API 可以用简单、可复用的代码构建数据输入流水线。例如,图像模型的输入流水线可以从分布式文件系统中的文件读取图像数据,对每张图像施加随机扰动(如添加随机噪声),再把随机选取的图像合并成一个 batch 用于训练。
tf.data API 引入了 tf.data.Dataset 抽象,它表示一个元素(element)序列,每个元素由一个或多个组件(component)构成。
有两种方法创建数据集dataset:
1. 从内存或文件中的数据源构建 Dataset
2. 通过数据转换(transformation),由一个或多个已有的 Dataset 对象得到新的 Dataset
Dataset 的基本创建方法
# NOTE(review): these snippets assume `import tensorflow as tf` has already
# been run; the article never shows the import.
# from_tensor_slices turns each entry of the list into one scalar element.
dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset
#<TensorSliceDataset shapes: (), types: tf.int32>
# Iterating yields one tensor per element; .numpy() converts it for printing.
for elem in dataset:
    print(elem.numpy())
#8
#3
#0
#8
#2
#1
dataset 是一个 Python 可迭代对象(iterable),因此也可以用 iter() 显式取得迭代器,再用 next() 逐个取出元素:
# iter() returns an iterator over the dataset; next() pulls its first element.
it = iter(dataset)
print(next(it).numpy())
#8
dataset 是一个 Dataset 对象,除了迭代之外,还可以用 reduce 方法把所有元素归约为单个值。下面的例子对所有元素求和:
# reduce(initial_state, reduce_func) folds every element into one value: the
# state starts at 0 and each element is added to it, i.e. a sum
# (8 + 3 + 0 + 8 + 2 + 1 = 22).
print(dataset.reduce(0, lambda state, value: state + value).numpy())
#22
数据集结构
数据集中的每个元素都具有相同的(嵌套)结构,结构中的每个组件可以是任何能由 tf.TypeSpec 表示的类型,
包括:
tf.Tensor、tf.SparseTensor、tf.RaggedTensor、tf.TensorArray 或 tf.data.Dataset。可以通过 Dataset.element_spec 查看元素的类型规格:
# from_tensor_slices slices along the first axis: a [4, 10] tensor yields 4
# elements, each of shape (10,).
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))
dataset1.element_spec
#TensorSpec(shape=(10,), dtype=tf.float32, name=None)
# A tuple of tensors (both with first dimension 4) is sliced component-wise,
# so each element is a (scalar float, length-100 int vector) pair.
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random.uniform([4]),
tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))
dataset2.element_spec
#(TensorSpec(shape=(), dtype=tf.float32, name=None),
# TensorSpec(shape=(100,), dtype=tf.int32, name=None))
# zip pairs up corresponding elements, so the element structure nests.
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
dataset3.element_spec
#(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
# (TensorSpec(shape=(), dtype=tf.float32, name=None),
# TensorSpec(shape=(100,), dtype=tf.int32, name=None)))
# Dataset containing a sparse tensor.
# from_tensors (unlike from_tensor_slices) keeps the whole value as a single
# element, hence the full [3, 4] shape in the spec below.
dataset4 = tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))
dataset4.element_spec
#SparseTensorSpec(TensorShape([3, 4]), tf.int32)
# Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type
#tensorflow.python.framework.sparse_tensor.SparseTensor
从 Python 生成器创建数据集
def count(stop):
    """Yield the integers 0, 1, ..., stop - 1.

    Yields nothing when ``stop <= 0``.
    """
    i = 0
    while i < stop:
        yield i
        i += 1
# Iterate the plain Python generator directly, before wrapping it in tf.data.
for n in count(5):
    print(n)
#0
#1
#2
#3
#4
# Wrap the Python generator `count` in a Dataset; `args` supplies the generator
# argument, so each pass over the dataset produces 0..24.
# NOTE: output_types/output_shapes are deprecated since TF 2.4 in favour of
# output_signature=tf.TensorSpec(shape=(), dtype=tf.int32).
ds_counter = tf.data.Dataset.from_generator(
    count, args=[25], output_types=tf.int32, output_shapes=())
# repeat() restarts the generator indefinitely and batch(10) groups consecutive
# elements, so some batches straddle the 24 -> 0 wrap-around; take(10) stops
# after ten batches.
for count_batch in ds_counter.repeat().batch(10).take(10):
    print(count_batch.numpy())
#[0 1 2 3 4 5 6 7 8 9]
#[10 11 12 13 14 15 16 17 18 19]
#[20 21 22 23 24 0 1 2 3 4]
#[ 5 6 7 8 9 10 11 12 13 14]
#[15 16 17 18 19 20 21 22 23 24]
#[0 1 2 3 4 5 6 7 8 9]
#[10 11 12 13 14 15 16 17 18 19]
#[20 21 22 23 24 0 1 2 3 4]
#[ 5 6 7 8 9 10 11 12 13 14]
#[15 16 17 18 19 20 21 22 23 24]
再看一个例子:下面的生成器每次产出一个 (序号, 变长序列) 对:
def gen_series():
    """Yield ``(i, series)`` pairs forever.

    ``i`` counts up from 0. ``series`` is a 1-D array of samples from
    ``np.random.normal`` (default mean 0, std 1) whose length is drawn from
    0..9, so consecutive elements generally differ in length.
    """
    # The article never imports numpy; import it here so the generator is
    # runnable on its own.
    import numpy as np

    i = 0
    while True:
        size = np.random.randint(0, 10)  # upper bound is exclusive: 0..9
        yield i, np.random.normal(size=(size,))
        i += 1
# The generator is infinite, so stop manually: pairs 0..6 are printed and the
# loop breaks right after printing i == 6.
for i, series in gen_series():
    print(i, ":", str(series))
    if i > 5:
        break
"""
0 : [ 0.9475 -0.6361 0.9765]
1 : [-0.555 1.3723 0.1027 -1.0957 0.141 ]
2 : [-0.5906 -1.2747 0.5064 1.104 0.1396 -0.1937 -0.3695 -0.5508]
3 : [ 0.2029 0.7422 1.3038 1.0698 1.7587 -0.7051]
4 : [ 0.4777 0.568 -0.7713 -0.0322 -1.0875]
5 : [ 0.2634 -0.3093 0.6087]
6 : [-0.1843 0.6568 0.2268 2.1317 -0.2758 -0.4531]
"""
# Wrap gen_series in a Dataset. The second component is declared with shape
# (None,) because the series length varies from element to element.
ds_series = tf.data.Dataset.from_generator(
gen_series,
output_types=(tf.int32, tf.float32),
output_shapes=((), (None,)))
ds_series
#<FlatMapDataset shapes: ((), (None,)), types: (tf.int32, tf.float32)>
需要注意的是,上述序列的长度各不相同,不能直接用 batch 组成批次(batch 要求同一批内元素形状一致)。可以改用 padded_batch,它会把每个批次内的序列补零到相同长度。
# shuffle(20) draws elements from a 20-element buffer; padded_batch(10) groups
# 10 elements and pads each series with zeros to the longest one in the batch
# (visible as the trailing 0. entries below).
ds_series_batch = ds_series.shuffle(20).padded_batch(10)
ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())
"""
[10 2 13 4 19 12 25 14 8 0]
[[ 0.6505 0.9339 -0.6796 0. 0. 0. 0. 0. ]
[ 1.4865 -0.5334 0. 0. 0. 0. 0. 0. ]
[ 0.9648 1.0677 -1.2092 -0.4564 0.9524 -0.5516 -0.8149 1.1307]
[-1.2593 0.8061 0.7738 -0.6441 -1.3384 1.2362 0. 0. ]
[ 0.6877 1.9626 -1.0171 -0.7908 0. 0. 0. 0. ]
[ 0.2111 1.687 -0.0555 -0.0242 -1.2556 1.1843 -0.509 1.5797]
[-0.6607 0. 0. 0. 0. 0. 0. 0. ]
[-0.3114 -0.6608 0. 0. 0. 0. 0. 0. ]
[ 0.8224 -1.2478 -0.9483 0.6411 -0.9707 1.659 -0.642 0. ]
[ 1.7515 1.3955 -0.9958 -0.1844 0.5085 -0.1619 1.0888 1.7601]]
"""
再看下对图像数据的读取
# Download and extract the flower photos archive; `flowers` is the local path.
flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
# Scale pixel values by 1/255 and apply random rotations (rotation_range=20).
img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
images, labels = next(img_gen.flow_from_directory(flowers))
#Found 3670 images belonging to 5 classes.
print(images.dtype, images.shape)
print(labels.dtype, labels.shape)
#float32 (32, 256, 256, 3)
#float32 (32, 5)
# Declare the batch dimension as None rather than 32: 3670 images in batches of
# 32 leave a final batch of 22 images every epoch, and a fixed leading
# dimension of 32 would make from_generator reject that batch.
ds = tf.data.Dataset.from_generator(
    img_gen.flow_from_directory, args=[flowers],
    output_types=(tf.float32, tf.float32),
    output_shapes=([None, 256, 256, 3], [None, 5])
)
ds
#<FlatMapDataset shapes: ((None, 256, 256, 3), (None, 5)), types: (tf.float32, tf.float32)>
读取并使用 TFRecord data
使用 tf.data.TFRecordDataset 把 TFRecord 文件读取为输入流:
# Creates a dataset that reads all of the examples from a single TFRecord file.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
'''
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
7905280/7904079 [==============================] - 1s 0us/step
'''
# TFRecordDataset takes a list of filenames; as the output below shows, each
# element is one raw record as a scalar tf.string.
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
#<TFRecordDatasetV2 shapes: (), types: tf.string>
上一篇: vue 数组中引入动态地址图片