Tensorflow部署到移动端
本文以训练一个简单的图像分类器为主线,讲述从数据准备->训练->验证->预测->模型转换->部署到Android手机整个流程。
1 数据准备->训练->验证->预测
数据准备->训练->验证->预测,这部分可参考tensorflowClassification,其讲述了两种训练图像分类器的方式。
需要注意的是mobilenet比较难训练,建议采用transfer learning进行retrain。直接训练很可能出现虽然loss不断下降,但是进行validation、prediction的时候精度总是0.5(二分类)。
2 模型转换
模型转换有两种方法,一种是直接使用python源码,另一种是使用bazel工具,第二种比较方便,本文采用第二种方式进行讲解。
2.1 安装Bazel
安装Bazel没有比官方的安装教程更为好的参考资料了,请移步Installing Bazel on Ubuntu 。
官方安装方法有三种,如图1所示,建议采用推荐的方法一。采用方法二进行安装对时候如果SB的IT对网络有限制,那么在执行“curl https://bazel.build/bazel-release.pub.gpg | sudo apt-key add -”的时候很有可能报错。
2.2 编译freeze_graph、toco
在模型转换的时候需要用到一些bazel编译出来的工具,比如freeze_graph、transform_graph、toco等,本文主要讲述模型转换中必要对工具。
在TensorFlow官方github下载tensorflow,解压之后所在文件夹为tensorflow-master,打开终端,进入tensorflow-master,运行以下指令:
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/contrib/lite/toco:toco
如果一切顺利,则其编译过程会像图2-5所示:
往往事与愿违,bazel编译工具虽然没有坑,但是往往由于某些原因导致编译失败,排错十四天之后,你会发现是因为bazel编译的时候会到https://bitbucket.org去下载一些文件(详见workspace.bzl),而该网站是没有给授权的。折腾很长时间发现是被IT坑了,心中难免不会有千千万万的*在奔腾。
bazel编译不成功,由于网站授权报的错如图6所示:
2.3 toco_landscape
TensorFlow有多种高级API,其生成模型多样,但都可使用toco工具进行转换为.tflite文件,toco_landscape如图7所示:
2.4 eval.pbtxt、.pb、.tflite
训练的时候会生成graph.pbtxt,但是万不可直接用训练的graph.pbtxt进行模型的freeze(企图使用训练的graph.pbtxt将模型转换为.pb格式再由.pb格式转换为.tflite格式的后果详见生成.tflite文件过程中遇到的问题及解决方案),以及后续的transform、toco都应当用验证的eval.pbtxt。下面以mobilenet_v1为例讲述eval.pbtxt的生成方法(更详细请参考tf_export_inference_graph.py):
import tensorflow as tf
slim = tf.contrib.slim
# Can be any nets you want to export
from nets import mobilenet_v1
def export_eval_pbtxt():
"""Export eval.pbtxt."""
g = tf.Graph()
with g.as_default():
inputs = tf.placeholder(dtype=tf.float32,shape=[None,224,224,3])
scope = mobilenet_v1.mobilenet_v1_arg_scope(
is_training=False, weight_decay=0.0)
with slim.arg_scope(scope):
_, _ = mobilenet_v1.mobilenet_v1(
inputs,
is_training=False,
depth_multiplier=FLAGS.depth_multiplier,
num_classes=FLAGS.num_classes)
eval_graph_file = '/home/lg/Desktop/mobilenet_v1_eval.pbtxt'
with tf.Session() as sess:
with open(eval_graph_file, 'w') as f:
f.write(str(g.as_graph_def()))
需要注意的是export_eval_pbtxt()函数中inputs = tf.placeholder(dtype=tf.float32,shape=[None,224,224,3]),placeholder的数据类型应该是float,否则在对单幅图像进行预测时使用tf.gfile.FastFile读取图片、解码之后送入feed_dict时会报如下图8的错误:
然后,frozen the graph:
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/home/lg/Desktop/mobilenet_v1_eval.pbtxt \
--input_checkpoint=/home/lg/Desktop/checkpoint/model.ckpt-10000 \
--input_binary=false \
--output_graph=/home/lg/Desktop/frozen_mobilenet_v1_224.pb \
--output_node_names=MobilenetV1/Predictions/Reshape_1 \
--checkpoint_version=2
最后,.pb转换为.tflite,可以选择保持FLOAT格式,或进行量化:
保持FLOAT:
bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/frozen_mobilenet_v1_224.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/home/lg/Desktop/frozen_graph_mobilenet_v1-FLOAT.tflite \
--inference_type=FLOAT \
--input_type=FLOAT \
--input_arrays=Placeholder \
--output_arrays=MobilenetV1/Predictions/Reshape_1 \
--input_shapes=1,224,224,3
量化QUANTIZED_UINT8:
bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/frozen_mobilenet_v1_224.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/home/lg/Desktop/frozen_graph_mobilenet_v1-QUANTIZED_UINT8.tflite \
--inference_type=QUANTIZED_UINT8 \
--input_type=QUANTIZED_UINT8 \
--input_arrays=Placeholder \
--output_arrays=MobilenetV1/Predictions/Reshape_1 \
--input_shapes=1,224,224,3 \
--default_ranges_min=0.0 \
--default_ranges_max=255.0
需要注意的是该处的量化是对数据类型的转换,和训练时的量化是两码事。训练时的量化可参考tf_train.py中:
if FLAGS.quantize:
tf.contrib.quantize.create_training_graph(quant_delay=get_quant_delay())
训练产生的模型文件如图9所示:
转换过程中的文件如图10所示:
由图10可看出,量化之后.tflite的大小约是FLOAT的。
2.5 确定.pb、.tflite中input、output的名称
最简单的方式是用tensorboard打开生成的eval.pbtxt文件,可视化之后很容易确定freeze_graph时参数output_node_names的值,以及toco时input_arrays、output_arrays的值。
对于mobilenet_v1,按照前文所述产生eval.pbtxt的方法,找到相应参数名称如图11、12所示:
对于inception_v3 ,按照前文所述产生eval.pbtxt的方法,找到相应参数名称如图13、14所示:
在直接使用train.pbtxt进行freeze_graph、toco的时候找到的input、output相应名称如图15、16所示,可见虽然output是正确的,但是input是错误的。
2.5 部署到Android进进手机
该部分可参考之前的文章Tensorflow Lite Android Demo App,需要做的只是把标签文件labels,模型文件.tflite替换掉原工程中对应对文件即可。
3 参考文献
TensorFlow Lite
Preparing models for mobile deployment
TensorFlow Lite & TensorFlow Compatibility Guide
Graph Transform Tool
更多资料请移步github:
https://github.com/GarryLau/tensorflowClassification