欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

关于TensorFlow中Estimator模板的基础理解与使用流程

程序员文章站 2022-05-31 18:46:38
...

理解

在TensorFlow 1.4版本之后,官方开始在入门文档中就鼓励使用高层的Estimator API。并且,许多开源代码中也使用了Estimator模板,因此我认为掌握好这个类使用方法,对于能够优雅地书写Tensorflow程序有着重要的意义。

什么是Estimator?在这里先不考虑源码实现细节以及类定义,首先建立一个整体的认识。Estimator,估计器,这个类的核心思想就是把一个网络封装起来,使用类方法中的trainevalpredict等等进行操作。具体的网络细节对于这个类的运行者来说是一个黑盒,只需要提供输入,选择相应的方法,就可以获得输出。

另外,TensorFlow中对于Estimator,不仅有预设好的对象可以直接生成,还可以自己定义。通俗的讲就是使用预先写好的网络框架(例如DNN)还是用户自定义的网络结构。当然了,预先写好的网络框架也不是死的,具体的隐层数目等等参数都是可以在初始化的时候进行设置的。预定义好的Estimator在大部分文档中被称作pre(made) Estimator,具体包含哪些类型的分类器,可以查看这一部分的文档。

Estimator类的主要结构如下,看不明白也没有关系,我们先关注初步的流程框架。需要重点注意的就是其中的model_fn函数,你会在接下来的使用流程中看到这个函数的作用。
关于TensorFlow中Estimator模板的基础理解与使用流程

基本使用流程

  1. 首先定义特征列(feature_columns)。这个是之后在Estimator对象初始化时需要接收的必要参数
		my_feature_columns = []
		for key in train_x.keys():
			my_feature_columns.append(
               	tf.feature_column.numeric_column(key=key))
  1. 之后初始化Estimator对象。这里分为两种情况:

    a)如果是使用预定义的Estimator(例如DNNClassifier)则可以直接调用其初始化函数。

		classifier = tf.estimator.DNNClassifier(
			    feature_columns=my_feature_columns,
			    hidden_units=[10, 10],
			    n_classes=3)

b) 如果自定义Estimator(意味着自定义的模型),则首先需要定义model_fn函数,描述模型细节,之后将model_fn与其他params一起传入Estimator的初始化函数中。

		classifier = tf.estimator.Estimator(
		        model_fn=my_model,
		        params={
		            'feature_columns': my_feature_columns,
		            'hidden_units': [10, 10],
		            'n_classes': 3,
		        })

其中params字典中的值将会被传入model_fn中用于定义自定义的模型。

前后进行对比,很容易看出两者的区别和共同点。相比之下,可以说自定义的模型比预设的模型多了一层壳,多传了一次参数

  1. 调用Esitimator对象中的trainevaluate等方法得到结果
	# train the model
	classifier.train(
	        input_fn=lambda:iris_data.train_input_fn(
	                         train_x, train_y, args.batch_size),
	        steps=args.train_steps)
	
	# evalute the model
	eval_result = classifier.evaluate(
	        input_fn=lambda:iris_data.eval_input_fn(
	                         test_x, test_y, args.batch_size))