关于TensorFlow中Estimator模板的基础理解与使用流程
理解
在TensorFlow 1.4版本之后,官方开始在入门文档中就鼓励使用高层的Estimator API。并且,许多开源代码中也使用了Estimator
模板,因此我认为掌握好这个类使用方法,对于能够优雅地书写Tensorflow程序有着重要的意义。
什么是Estimator
?在这里先不考虑源码实现细节以及类定义,首先建立一个整体的认识。Estimator
,估计器,这个类的核心思想就是把一个网络封装起来,使用类方法中的train
、eval
、predict
等等进行操作。具体的网络细节对于这个类的运行者来说是一个黑盒,只需要提供输入,选择相应的方法,就可以获得输出。
另外,TensorFlow中对于Estimator
,不仅有预设好的对象可以直接生成,还可以自己定义。通俗的讲就是使用预先写好的网络框架(例如DNN)还是用户自定义的网络结构。当然了,预先写好的网络框架也不是死的,具体的隐层数目等等参数都是可以在初始化的时候进行设置的。预定义好的Estimator
在大部分文档中被称作pre(made) Estimator,具体包含哪些类型的分类器,可以查看这一部分的文档。
Estimator
类的主要结构如下,看不明白也没有关系,我们先关注初步的流程框架。需要重点注意的就是其中的model_fn
函数,你会在接下来的使用流程中看到这个函数的作用。
基本使用流程
-
首先定义特征列(
feature_columns
)。这个是之后在Estimator
对象初始化时需要接收的必要参数
my_feature_columns = []
for key in train_x.keys():
my_feature_columns.append(
tf.feature_column.numeric_column(key=key))
-
之后初始化
Estimator
对象。这里分为两种情况:a)如果是使用预定义的
Estimator
(例如DNNClassifier
)则可以直接调用其初始化函数。
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3)
b) 如果自定义Estimator
(意味着自定义的模型),则首先需要定义model_fn
函数,描述模型细节,之后将model_fn
与其他params
一起传入Estimator
的初始化函数中。
classifier = tf.estimator.Estimator(
model_fn=my_model,
params={
'feature_columns': my_feature_columns,
'hidden_units': [10, 10],
'n_classes': 3,
})
其中params
字典中的值将会被传入model_fn
中用于定义自定义的模型。
前后进行对比,很容易看出两者的区别和共同点。相比之下,可以说自定义的模型比预设的模型多了一层壳,多传了一次参数。
- 调用
Esitimator
对象中的train
、evaluate
等方法得到结果
# train the model
classifier.train(
input_fn=lambda:iris_data.train_input_fn(
train_x, train_y, args.batch_size),
steps=args.train_steps)
# evalute the model
eval_result = classifier.evaluate(
input_fn=lambda:iris_data.eval_input_fn(
test_x, test_y, args.batch_size))