Loading CSV data with tf.data
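This tutorial uses data on Titanic passengers. The model predicts how likely each passenger was to survive, based on features such as age, sex, ticket class, and whether the passenger was traveling alone.
I. Setup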
!pip install -q tensorflow==2.0.0-beta1
from __future__ import absolute_import, division, print_function, unicode_literals
import functools
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
TRAIN_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/train.csv"
TEST_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/eval.csv"
train_file_path = tf.keras.utils.get_file("train.csv", TRAIN_DATA_URL)
test_file_path = tf.keras.utils.get_file("eval.csv", TEST_DATA_URL)
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/eval.csv
16384/13049 [=====================================] - 0s 0us/step
# Make numpy values easier to read.
np.set_printoptions(precision=3, suppress=True)
II. Loading the data
To start, print the first few lines of the CSV file to see how it is formatted.
!head {train_file_path}
survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone
0,male,22.0,1,0,7.25,Third,unknown,Southampton,n
1,female,38.0,1,0,71.2833,First,C,Cherbourg,n
1,female,26.0,0,0,7.925,Third,unknown,Southampton,y
1,female,35.0,1,0,53.1,First,C,Southampton,n
0,male,28.0,0,0,8.4583,Third,unknown,Queenstown,y
0,male,2.0,3,1,21.075,Third,unknown,Southampton,n
1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n
1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n
1,female,4.0,1,1,16.7,Third,G,Southampton,n
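As you can see, each column in the CSV file has a name. The dataset constructor picks these names up automatically. If the first line of your file does not contain column names, pass them as a list of strings to the column_names argument of make_csv_dataset.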
CSV_COLUMNS = ['survived', 'sex', 'age', 'n_siblings_spouses', 'parch', 'fare', 'class', 'deck', 'embark_town', 'alone']
dataset = tf.data.experimental.make_csv_dataset(
    ...,
    column_names=CSV_COLUMNS,
    ...)
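As a rough analogy only (this is not how make_csv_dataset works internally), Python's standard csv.DictReader follows the same rule: it takes field names from the first line unless you pass them explicitly through fieldnames. The two sample strings below are copied from the head of train.csv shown earlier; read_rows is a hypothetical helper.

```python
import csv
import io

# Sample data copied from the `head` output above.
WITH_HEADER = (
    "survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone\n"
    "0,male,22.0,1,0,7.25,Third,unknown,Southampton,n\n"
)
WITHOUT_HEADER = "1,female,38.0,1,0,71.2833,First,C,Cherbourg,n\n"

CSV_COLUMNS = ['survived', 'sex', 'age', 'n_siblings_spouses', 'parch',
               'fare', 'class', 'deck', 'embark_town', 'alone']


def read_rows(text, column_names=None):
    """Parse CSV text into dicts; names come from the first line unless column_names is given."""
    reader = csv.DictReader(io.StringIO(text), fieldnames=column_names)
    return [dict(row) for row in reader]


inferred = read_rows(WITH_HEADER)                               # header line supplies the names
explicit = read_rows(WITHOUT_HEADER, column_names=CSV_COLUMNS)  # names passed in by hand
print(inferred[0]['sex'], explicit[0]['embark_town'])  # male Cherbourg
```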
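This example uses all of the columns. If you need to omit some columns from the dataset, create a list of just the columns you plan to use and pass it to the constructor's (optional) select_columns argument.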
dataset = tf.data.experimental.make_csv_dataset(
    ...,
    select_columns=columns_to_use,
    ...)
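The effect of select_columns can be pictured with a small plain-Python helper. select_columns_from_rows below is hypothetical (it is not a TensorFlow function), and the two rows are the first two data rows of train.csv trimmed to four columns for brevity.

```python
def select_columns_from_rows(rows, columns_to_use):
    """Hypothetical helper: keep only the requested columns of every row, in the requested order."""
    if rows:
        missing = set(columns_to_use) - set(rows[0])
        if missing:
            raise KeyError("unknown columns: {}".format(sorted(missing)))
    return [{name: row[name] for name in columns_to_use} for row in rows]


# The first two data rows of train.csv, trimmed to four columns.
rows = [
    {'survived': '0', 'sex': 'male', 'age': '22.0', 'class': 'Third'},
    {'survived': '1', 'sex': 'female', 'age': '38.0', 'class': 'First'},
]
columns_to_use = ['survived', 'age', 'class']
print(select_columns_from_rows(rows, columns_to_use)[1])
# {'survived': '1', 'age': '38.0', 'class': 'First'}
```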
You also need to identify the column that holds the value the model is meant to predict.
LABEL_COLUMN = 'survived'
LABELS = [0, 1]
Now read the CSV data from the file and create a dataset. (For the full documentation, see tf.data.experimental.make_csv_dataset.)
def get_dataset(file_path):
    dataset = tf.data.experimental.make_csv_dataset(
        file_path,
        batch_size=12,  # Artificially small so the examples are easy to show
        label_name=LABEL_COLUMN,
        na_value="?",
        num_epochs=1,
        ignore_errors=True)
    return dataset
raw_train_data = get_dataset(train_file_path)
raw_test_data = get_dataset(test_file_path)
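Each item in the dataset is a batch, represented as a tuple of (many examples, many labels). The examples are organized as column-based tensors (rather than row-based tensors), and each column tensor has as many elements as the batch size (12 in this example).
It might help to look at an example.
examples, labels = next(iter(raw_train_data))  # The first batch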
print("EXAMPLES: \n", examples, "\n")
print("LABELS: \n", labels)
EXAMPLES:
OrderedDict([('sex', <tf.Tensor: id=170, shape=(12,), dtype=string, numpy=
array([b'male', b'male', b'female', b'female', b'male', b'male',
b'female', b'male', b'male', b'female', b'female', b'male'],
dtype=object)>), ('age', <tf.Tensor: id=162, shape=(12,), dtype=float32, numpy=
array([66., 28., 16., 6., 28., 16., 31., 36., 35., 33., 3., 37.],
dtype=float32)>), ('n_siblings_spouses', <tf.Tensor: id=168, shape=(12,), dtype=int32, numpy=array([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1], dtype=int32)>), ('parch', <tf.Tensor: id=169, shape=(12,), dtype=int32, numpy=array([0, 0, 0, 1, 0, 0, 0, 1, 0, 2, 2, 0], dtype=int32)>), ('fare', <tf.Tensor: id=167, shape=(12,), dtype=float32, numpy=
array([10.5 , 0. , 7.75 , 33. , 39.6 , 8.05 , 7.854, 24.15 ,
10.5 , 26. , 41.579, 53.1 ], dtype=float32)>), ('class', <tf.Tensor: id=164, shape=(12,), dtype=string, numpy=
array([b'Second', b'Second', b'Third', b'Second', b'First', b'Third',
b'Third', b'Third', b'Second', b'Second', b'Second', b'First'],
dtype=object)>), ('deck', <tf.Tensor: id=165, shape=(12,), dtype=string, numpy=
array([b'unknown', b'unknown', b'unknown', b'unknown', b'unknown',
b'unknown', b'unknown', b'unknown', b'unknown', b'unknown',
b'unknown', b'C'], dtype=object)>), ('embark_town', <tf.Tensor: id=166, shape=(12,), dtype=string, numpy=
array([b'Southampton', b'Southampton', b'Queenstown', b'Southampton',
b'Cherbourg', b'Southampton', b'Southampton', b'Southampton',
b'Southampton', b'Southampton', b'Cherbourg', b'Southampton'],
dtype=object)>), ('alone', <tf.Tensor: id=163, shape=(12,), dtype=string, numpy=
array([b'y', b'y', b'y', b'n', b'y', b'y', b'y', b'n', b'y', b'n', b'n',
b'n'], dtype=object)>)])
LABELS:
tf.Tensor([0 0 1 1 0 0 0 0 0 1 1 0], shape=(12,), dtype=int32)
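To make the column-based layout concrete, here is a plain-Python sketch that groups rows into batches, stores each column as a list of batch_size values, and splits off the label column. It is only an illustration: unlike make_csv_dataset it does not shuffle and does not convert types (the real dataset produced float32, int32, and string tensors above), and the helper name column_major_batches is made up.

```python
import csv
import io

# The first four data rows of train.csv, copied from the `head` output above.
SAMPLE = """survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone
0,male,22.0,1,0,7.25,Third,unknown,Southampton,n
1,female,38.0,1,0,71.2833,First,C,Cherbourg,n
1,female,26.0,0,0,7.925,Third,unknown,Southampton,y
1,female,35.0,1,0,53.1,First,C,Southampton,n
"""


def column_major_batches(text, batch_size, label_name):
    """Yield (features, labels); features maps each column name to a list of batch values."""
    rows = list(csv.DictReader(io.StringIO(text)))
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        features = {name: [row[name] for row in chunk]
                    for name in chunk[0] if name != label_name}
        labels = [int(row[label_name]) for row in chunk]
        yield features, labels


batches = list(column_major_batches(SAMPLE, batch_size=2, label_name='survived'))
features, labels = batches[0]
print(features['sex'])  # ['male', 'female']
print(labels)           # [0, 1]
```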
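III. Data preprocessing
1. Categorical data
Some of the columns in the CSV data are categorical, meaning each value must come from a limited set of options.
Use the tf.feature_column API to create a collection of tf.feature_column.indicator_column objects, one for each categorical column.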
CATEGORIES = {
    'sex': ['male', 'female'],
    'class' : ['First', 'Second', 'Third'],
    'deck' : ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'],
    # Spelled exactly as in the CSV data, otherwise the value never matches
    'embark_town' : ['Cherbourg', 'Southampton', 'Queenstown'],
    'alone' : ['y', 'n']
}
categorical_columns = []
for feature, vocab in CATEGORIES.items():
    cat_col = tf.feature_column.categorical_column_with_vocabulary_list(
        key=feature, vocabulary_list=vocab)
    categorical_columns.append(tf.feature_column.indicator_column(cat_col))
# What you just created
categorical_columns
[IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='alone', vocabulary_list=('y', 'n'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),
IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='deck', vocabulary_list=('A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),
IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='sex', vocabulary_list=('male', 'female'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),
IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='embark_town', vocabulary_list=('Cherbourg', 'Southampton', 'Queenstown'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),
IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='class', vocabulary_list=('First', 'Second', 'Third'), dtype=tf.string, default_value=-1, num_oov_buckets=0))]
These columns will become part of the input processing when you build the model later.
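In plain Python, an indicator column over a vocabulary list behaves roughly like the one-hot helper below. As we understand the defaults shown in the output above (default_value=-1, num_oov_buckets=0), a string that is not in the vocabulary gets no bucket, so its indicator vector is all zeros. That matters here: most deck values in the data are 'unknown', and a misspelled vocabulary entry such as 'Southhampton' would never match the 'Southampton' rows. The helper name one_hot is ours, not a TensorFlow function.

```python
CATEGORIES = {
    'sex': ['male', 'female'],
    'class': ['First', 'Second', 'Third'],
    'deck': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'],
    'embark_town': ['Cherbourg', 'Southampton', 'Queenstown'],
    'alone': ['y', 'n'],
}


def one_hot(value, vocab):
    """Return a one-hot list for value; values outside vocab produce all zeros."""
    return [1.0 if value == word else 0.0 for word in vocab]


print(one_hot('Third', CATEGORIES['class']))   # [0.0, 0.0, 1.0]
print(one_hot('unknown', CATEGORIES['deck']))  # ten 0.0 values
# A misspelled vocabulary entry silently turns every matching row into zeros:
print(one_hot('Southampton', ['Cherbourg', 'Southhampton', 'Queenstown']))  # [0.0, 0.0, 0.0]
```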
2. Continuous data
Continuous data needs to be normalized.
Write a function that rescales these values and reshapes them into a 2-D tensor.
def process_continuous_data(mean, data):
    # Rescale by 1/(2*mean) so that the column mean maps to 0.5
    data = tf.cast(data, tf.float32) * 1/(2*mean)
    # One value per row: shape [batch_size, 1]
    return tf.reshape(data, [-1, 1])
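Numerically, the function divides each value by twice the column mean, so a value equal to the mean becomes 0.5; the result is a rescaling, not a zero-mean, unit-variance standardization. The pure-Python mirror below (a hypothetical stand-in for the TensorFlow version, using lists instead of tensors) shows both the scaling and the [-1, 1] reshape into one value per row. The mean value is the one this tutorial uses for 'age'.

```python
def process_continuous_data_py(mean, data):
    """List-based mirror of process_continuous_data: scale by 1/(2*mean), one value per row."""
    return [[float(x) / (2 * mean)] for x in data]


AGE_MEAN = 29.631308           # the mean this tutorial uses for 'age'
ages = [22.0, 38.0, AGE_MEAN]  # two ages from train.csv plus the mean itself
scaled = process_continuous_data_py(AGE_MEAN, ages)
print(scaled[2])  # [0.5]
```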
Now create a collection of numeric columns. The tf.feature_column.numeric_column API accepts a normalizer_fn argument. Pass it a functools.partial that binds the normalizing function to each column's mean.
MEANS = {
    'age' : 29.631308,
    'n_siblings_spouses' : 0.545455,
    'parch' : 0.379585,
    'fare' : 34.385399
}
numerical_columns = []
for feature in MEANS.keys():
    num_col = tf.feature_column.numeric_column(
        feature, normalizer_fn=functools.partial(process_continuous_data, MEANS[feature]))
    numerical_columns.append(num_col)
# What you just created
numerical_columns
[NumericColumn(key='parch', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=functools.partial(<function process_continuous_data at 0x7f8eaaad1d90>, 0.379585)),
NumericColumn(key='n_siblings_spouses', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=functools.partial(<function process_continuous_data at 0x7f8eaaad1d90>, 0.545455)),
NumericColumn(key='age', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=functools.partial(<function process_continuous_data at 0x7f8eaaad1d90>, 29.631308)),
NumericColumn(key='fare', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=functools.partial(<function process_continuous_data at 0x7f8eaaad1d90>, 34.385399))]
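One reason functools.partial is a safe choice inside the loop: it stores the mean at the moment each normalizer is created. A lambda written in the same loop would look up the loop variable only when it is called, so every normalizer would end up using the last feature's mean. The sketch below shows the difference with a plain-Python scale function (our stand-in, not the TensorFlow normalizer); it assumes Python 3.7+ dictionary ordering, so 'fare' is the last key.

```python
import functools

MEANS = {'age': 29.631308, 'n_siblings_spouses': 0.545455,
         'parch': 0.379585, 'fare': 34.385399}


def scale(mean, value):
    """Stand-in normalizer: divide by twice the mean, as process_continuous_data does."""
    return value / (2 * mean)


# functools.partial binds the mean immediately, once per feature.
partial_fns = {}
for feature in MEANS:
    partial_fns[feature] = functools.partial(scale, MEANS[feature])

# A lambda only looks up `feature` when called, after the loop has finished,
# so every entry ends up using the mean of the last key ('fare').
lambda_fns = {}
for feature in MEANS:
    lambda_fns[feature] = lambda value: scale(MEANS[feature], value)

print(partial_fns['age'](29.631308))  # 0.5
print(lambda_fns['age'](29.631308) == scale(MEANS['fare'], 29.631308))  # True
```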
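The normalization used here requires knowing each column's mean in advance. To compute normalized values over a continuous stream of data, consider TensorFlow Transform.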
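The hard-coded MEANS work only because the means are known up front (the tutorial does not say how they were obtained). If the data instead arrives as a stream, a statistic such as the mean can still be computed in a single pass with an incremental update. The RunningMean class below is a hypothetical plain-Python illustration of that idea; it is not TensorFlow Transform's API.

```python
class RunningMean:
    """Hypothetical helper: track the mean of a stream without storing the values."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0

    def update(self, value):
        self.count += 1
        # Move the current mean toward the new value by 1/count of the gap.
        self.mean += (value - self.mean) / self.count
        return self.mean


tracker = RunningMean()
for age in [22.0, 38.0, 26.0, 35.0]:  # ages from the first four rows of train.csv
    tracker.update(age)
print(tracker.count, round(tracker.mean, 6))  # 4 30.25
```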
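3. Creating a preprocessing layer
Combine the two collections of feature columns and pass them to tf.keras.layers.DenseFeatures to create an input layer that performs the preprocessing.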
preprocessing_layer = tf.keras.layers.DenseFeatures(categorical_columns+numerical_columns)
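DenseFeatures turns each batch's dictionary of columns into one dense matrix by transforming every feature column and concatenating the results. The sketch below does the same for a single row in plain Python, which also shows that the preprocessed input is 24 numbers wide (2 + 3 + 10 + 3 + 2 one-hot slots for the categorical columns, plus 4 numeric values). The function encode_row is hypothetical, and ordering the columns by name is our assumption about how DenseFeatures arranges them; check the layer's documentation before relying on a particular position.

```python
CATEGORIES = {
    'sex': ['male', 'female'],
    'class': ['First', 'Second', 'Third'],
    'deck': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'],
    'embark_town': ['Cherbourg', 'Southampton', 'Queenstown'],
    'alone': ['y', 'n'],
}
MEANS = {'age': 29.631308, 'n_siblings_spouses': 0.545455,
         'parch': 0.379585, 'fare': 34.385399}


def encode_row(row):
    """Concatenate one-hot categorical parts and scaled numeric parts into one flat list."""
    vector = []
    for name in sorted(list(CATEGORIES) + list(MEANS)):  # assumed order: by column name
        if name in CATEGORIES:
            vector.extend(1.0 if row[name] == word else 0.0 for word in CATEGORIES[name])
        else:
            vector.append(float(row[name]) / (2 * MEANS[name]))
    return vector


# First data row of train.csv, with the 'survived' label already split off.
row = {'sex': 'male', 'age': 22.0, 'n_siblings_spouses': 1, 'parch': 0, 'fare': 7.25,
       'class': 'Third', 'deck': 'unknown', 'embark_town': 'Southampton', 'alone': 'n'}
vector = encode_row(row)
print(len(vector))  # 24
```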
IV. Building the model
Build a tf.keras.Sequential model, starting with preprocessing_layer.
model = tf.keras.Sequential([
    preprocessing_layer,
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'])
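A quick arithmetic check of the model's size, assuming the 24-wide preprocessed input worked out from the feature columns (2 + 3 + 10 + 3 + 2 categorical slots plus 4 numeric columns) and assuming DenseFeatures itself holds no trainable weights for indicator and numeric columns: each Dense layer holds inputs × units weights plus one bias per unit. model.summary() is the authoritative source; this is only a hand calculation.

```python
def dense_params(n_in, units):
    """Weights (n_in * units) plus biases (units) of a fully connected layer."""
    return n_in * units + units


n_features = (2 + 3 + 10 + 3 + 2) + 4  # categorical one-hot widths + numeric columns
layer_shapes = [(n_features, 128), (128, 128), (128, 1)]
counts = [dense_params(n_in, units) for n_in, units in layer_shapes]
print(n_features, counts, sum(counts))  # 24 [3200, 16512, 129] 19841
```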
V. Training, evaluation, and prediction
Now you can instantiate and train the model.
train_data = raw_train_data.shuffle(500)
test_data = raw_test_data
model.fit(train_data, epochs=20)
Epoch 1/20
53/53 [==============================] - 2s 34ms/step - loss: 0.5244 - accuracy: 0.7071
Epoch 2/20
53/53 [==============================] - 0s 4ms/step - loss: 0.4292 - accuracy: 0.8164
Epoch 3/20
53/53 [==============================] - 0s 3ms/step - loss: 0.4106 - accuracy: 0.8249
Epoch 4/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3997 - accuracy: 0.8367
Epoch 5/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3912 - accuracy: 0.8381
Epoch 6/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3840 - accuracy: 0.8388
Epoch 7/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3780 - accuracy: 0.8458
Epoch 8/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3726 - accuracy: 0.8456
Epoch 9/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3673 - accuracy: 0.8518
Epoch 10/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3630 - accuracy: 0.8557
Epoch 11/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3588 - accuracy: 0.8632
Epoch 12/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3543 - accuracy: 0.8609
Epoch 13/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3509 - accuracy: 0.8622
Epoch 14/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3475 - accuracy: 0.8573
Epoch 15/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3440 - accuracy: 0.8636
Epoch 16/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3409 - accuracy: 0.8582
Epoch 17/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3381 - accuracy: 0.8617
Epoch 18/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3352 - accuracy: 0.8594
Epoch 19/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3321 - accuracy: 0.8617
Epoch 20/20
53/53 [==============================] - 0s 3ms/step - loss: 0.3304 - accuracy: 0.8601
<tensorflow.python.keras.callbacks.History at 0x7f8ea5737358>
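Dataset.shuffle(500) keeps a buffer of up to 500 elements and emits a randomly chosen one each time, refilling the buffer from the input. Note that the elements of raw_train_data are already batches of 12 (and make_csv_dataset shuffles rows by default), so this call reorders whole batches. The generator below is a plain-Python sketch of the buffer idea with a hypothetical name; tf.data's real implementation may differ in details.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in a locally shuffled order using a buffer of at most buffer_size items."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Buffer is full: emit a uniformly chosen element to free one slot.
            index = rng.randrange(len(buffer))
            buffer[index], buffer[-1] = buffer[-1], buffer[index]
            yield buffer.pop()
    # Input exhausted: emit whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


order = list(buffered_shuffle(range(20), buffer_size=5, seed=0))
print(sorted(order) == list(range(20)))  # True
```

With a buffer of size 1 the order is unchanged, and the first element emitted always comes from the first buffer_size inputs, which is why a small buffer gives only a local shuffle.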
Once the model is trained, you can check its accuracy on the test_data set.
test_loss, test_accuracy = model.evaluate(test_data)
print('\n\nTest Loss {}, Test Accuracy {}'.format(test_loss, test_accuracy))
22/Unknown - 0s 22ms/step - loss: 0.4439 - accuracy: 0.8068
Test Loss 0.4438588565046137, Test Accuracy 0.8068181872367859
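For reference, here is roughly what the two reported numbers measure for a sigmoid output: binary cross-entropy averaged over examples, and binary accuracy with a 0.5 threshold. These plain-Python functions are our own approximation of the Keras metrics (the clipping constant 1e-7 is, as far as we know, Keras's default epsilon), and the labels and probabilities in the example are made up.

```python
import math


def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy; probabilities are clipped away from exactly 0 and 1."""
    total = 0.0
    for t, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)
        total += -(t * math.log(p) + (1 - t) * math.log(1.0 - p))
    return total / len(y_true)


def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of examples whose thresholded prediction equals the label."""
    hits = sum(1 for t, p in zip(y_true, y_prob) if (p > threshold) == bool(t))
    return hits / len(y_true)


labels = [1, 1, 0, 1]         # made-up labels
probs = [0.9, 0.6, 0.2, 0.4]  # made-up predicted survival probabilities
print(binary_accuracy(labels, probs))                # 0.75
print(round(binary_crossentropy(labels, probs), 4))  # 0.4389
```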
Use tf.keras.Model.predict to infer labels for a single batch or for a dataset of batches.
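# make_csv_dataset shuffles by default, so two separate passes over test_data
# return rows in different orders. Take one batch and keep its labels with it.
test_batch = next(iter(test_data))
predictions = model.predict(tf.data.Dataset.from_tensors(test_batch))
# Show some results
for prediction, survived in zip(predictions[:10], test_batch[1][:10]):
    print("Predicted survival: {:.2%}".format(prediction[0]),
          " | Actual outcome: ",
          ("SURVIVED" if bool(survived) else "DIED"))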
Predicted survival: 91.51% | Actual outcome: SURVIVED
Predicted survival: 99.93% | Actual outcome: SURVIVED
Predicted survival: 5.72% | Actual outcome: DIED
Predicted survival: 55.36% | Actual outcome: SURVIVED
Predicted survival: 12.10% | Actual outcome: DIED
Predicted survival: 22.46% | Actual outcome: DIED
Predicted survival: 8.55% | Actual outcome: DIED
Predicted survival: 19.70% | Actual outcome: DIED
Predicted survival: 2.71% | Actual outcome: DIED
Predicted survival: 84.95% | Actual outcome: SURVIVED
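A caution when pairing predictions with labels: make_csv_dataset shuffles by default, so each separate pass over test_data can return rows in a different order, and predictions should be compared with labels from the same pass (or the same batch). The plain-Python sketch below imitates two independent passes with two seeded shuffles; the records and the helper shuffled_pass are hypothetical. Pairing outputs from one pass with labels from another mixes up examples, while unpacking a single pass keeps each label with its own example.

```python
import random

# Hypothetical (example_id, label) records standing in for one test batch.
records = [(example_id, example_id % 2) for example_id in range(12)]


def shuffled_pass(data, seed):
    """Return one pass over data in a fresh random order, like iterating a reshuffling dataset."""
    copy = list(data)
    random.Random(seed).shuffle(copy)
    return copy


pass_for_predictions = shuffled_pass(records, seed=1)
pass_for_labels = shuffled_pass(records, seed=2)

# Pairing across passes: the same position no longer refers to the same example.
mismatches = sum(a[0] != b[0] for a, b in zip(pass_for_predictions, pass_for_labels))

# Pairing within one pass: every label still belongs to its own example.
ids, labels = zip(*pass_for_predictions)
aligned = all(label == example_id % 2 for example_id, label in zip(ids, labels))
print(mismatches > 0, aligned)
```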