tf.estimator 训练demo
程序员文章站
2024-01-19 13:34:52
...
本文记录如何使用 tf.estimator.DNNClassifier 对 iris（鸢尾花）数据集进行分类。
代码如下（测试版本：TensorFlow 1.x）：
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import os

import numpy as np
import tensorflow as tf
from six.moves.urllib.request import urlopen
# tf.enable_eager_execution()
# Data sets
# sess=tf.Session()
# coord=tf.train.Coordinator()
# threads=tf.train.start_queue_runners(sess=sess,coord=coord)
# Remote directory that hosts both iris CSV files.
_IRIS_DATA_URL = "http://download.tensorflow.org/data/"

# Local file names, plus the URL each one is downloaded from when missing.
IRIS_TRAINING = "iris_training.csv"
IRIS_TRAINING_URL = _IRIS_DATA_URL + IRIS_TRAINING
IRIS_TEST = "iris_test.csv"
IRIS_TEST_URL = _IRIS_DATA_URL + IRIS_TEST
def _maybe_download(filename, url):
    """Download ``url`` to ``filename`` unless that file already exists.

    An existing file is reused as-is; its contents are not validated.
    """
    if os.path.exists(filename):
        return
    # closing() guarantees the HTTP response is released; the response object
    # returned by six's urlopen is not a context manager on Python 2.
    with contextlib.closing(urlopen(url)) as response:
        raw = response.read()
    with open(filename, "wb") as f:
        f.write(raw)


def _iris_input_fn(files, batch_size, training):
    """Build the tf.data pipeline for the iris CSV files.

    Call this from inside the ``input_fn`` handed to the Estimator, so the
    dataset is created in the graph the Estimator builds for train/evaluate.

    Args:
        files: list of CSV file paths.
        batch_size: number of examples per batch.
        training: if True, shuffle (buffer 1000) and repeat indefinitely.

    Returns:
        A batched ``tf.data.Dataset`` of ``({"x": features}, label)`` pairs.
        Each example has 4 float features and an int32 label of shape (1,).
    """
    def _has_no_letters(line):
        # Drop any line containing an ASCII letter (e.g. the header row).
        # The original class `[a-z|A-Z]` also matched a literal "|" by mistake.
        return tf.logical_not(tf.strings.regex_full_match(line, r".*[a-zA-Z].*"))

    def _parse_line(line):
        # Split "f1,f2,f3,f4,label" and convert every field to a number.
        fields = tf.strings.to_number(
            tf.strings.split(tf.reshape(line, [-1]), ",").values)
        features, label = tf.split(fields, [4, 1], axis=0)
        return {"x": features}, tf.cast(label, tf.int32)

    dataset = tf.data.TextLineDataset(files)
    dataset = dataset.filter(_has_no_letters)
    dataset = dataset.map(_parse_line, num_parallel_calls=4)
    if training:
        dataset = dataset.shuffle(1000).repeat()
    return dataset.batch(batch_size).prefetch(10)


def main():
    """Train, evaluate and query a DNNClassifier on the iris dataset."""
    # Fetch the training and test sets unless they are already stored locally.
    _maybe_download(IRIS_TRAINING, IRIS_TRAINING_URL)
    _maybe_download(IRIS_TEST, IRIS_TEST_URL)

    # All four features are real-valued and are fed as one dense "x" vector.
    feature_columns = [tf.feature_column.numeric_column("x", shape=[4])]
    classifier = tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[10, 20, 10],
        n_classes=3,
        model_dir="./iris_model1")

    # Additional checkpoints every 1000 steps, written to a separate directory.
    hooks = [
        tf.train.CheckpointSaverHook(checkpoint_dir="./iris_model1/cp",
                                     save_steps=1000),
    ]

    # The lambdas defer dataset construction until the Estimator builds its
    # own graph; a dataset created out here would live in a different graph.
    classifier.train(
        input_fn=lambda: _iris_input_fn([IRIS_TRAINING], 128, True),
        max_steps=10000,
        hooks=hooks)

    # Evaluate accuracy on the held-out set, then on the training set.
    test_accuracy = classifier.evaluate(
        input_fn=lambda: _iris_input_fn([IRIS_TEST], 16, False))["accuracy"]
    print("\nTest Accuracy: {0:f}\n".format(test_accuracy))
    train_accuracy = classifier.evaluate(
        input_fn=lambda: _iris_input_fn([IRIS_TRAINING], 16, False))["accuracy"]
    print("\nTraining Accuracy: {0:f}\n".format(train_accuracy))

    # Classify three new flower samples.
    new_samples = np.array(
        [[6.4, 3.2, 4.5, 1.5],
         [5.8, 3.1, 5.0, 1.7],
         [6.3, 2.9, 5.6, 1.8]],
        dtype=np.float32)
    predict_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": new_samples},
        num_epochs=1,
        shuffle=False)
    predicted_classes = [
        p["classes"] for p in classifier.predict(input_fn=predict_input_fn)
    ]
    print("New Samples, Class Predictions:{}\n".format(predicted_classes))
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
#out:Test Accuracy: 0.966667
#
#
#
记录:
坑1:
WARNING:tensorflow:The graph (<tensorflow.python.framework.ops.Graph object at 0x0000019FA3FAD550>) of the iterator is different from the graph (<tensorflow.python.framework.ops.Graph object at 0x0000019FAC815B70>) the dataset: <DatasetV1Adapter shapes: ((4,), (1,)), types: (tf.float32, tf.float32)> was created in. If you are using the Estimator API, make sure that no part of the dataset returned by the `input_fn` function is defined outside the `input_fn` function.Please ensure that all datasets in the pipeline are created in the same graph as the iterator. NOTE: This warning will become an error in future versions of TensorFlow.
解决：tf.estimator 的计算图是在调用 train（以及 evaluate / predict）时才构建的。因此，像上面代码中使用 DNNClassifier 时，传给 train 的 input_fn 必须在函数内部构建整个 data pipeline（Dataset）。不能先在 input_fn 外部构建好 Dataset，再由 input_fn 直接返回它，否则 Dataset 与迭代器会处于不同的计算图中。
推荐阅读
-
tf.estimator 训练demo
-
解决加载自定义训练模型OSError: Unable to open file (unable to open file: name = ‘./model/LeNet_model‘, errno = 2
-
《基于注解的SpringMVC增删改DEMO源码》Maven版
-
tf.kerasr入门示例:Lenet手写字符分类(2扩展) eager模式下Sequence生成器方式加载数据并训练
-
超算新突破将深度学习训练时间缩减到数分钟
-
tensorflow2.x 下 eager mode的训练流程
-
微信小程序picker组件简单用法示例【附demo源码下载】
-
Tika解析文件Demo
-
laravel4 简单demo
-
训练Word2Vec报错:RuntimeError: you must first build vocabulary before training the model