欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Tensorflow部署到移动端

程序员文章站 2024-03-24 15:26:16
...

本文以训练一个简单的图像分类器为主线,讲述从数据准备->训练->验证->预测->模型转换->部署到Android手机整个流程。

1 数据准备->训练->验证->预测

数据准备->训练->验证->预测,这部分可参考tensorflowClassification,其讲述了两种训练图像分类器的方式。
需要注意的是mobilenet比较难训练,建议采用transfer learning进行retrain。直接训练很可能出现虽然loss不断下降,但是进行validation、prediction的时候精度总是0.5(二分类)。

2 模型转换

模型转换有两种方法,一种是直接使用python源码,另一种是使用bazel工具,第二种比较方便,本文采用第二种方式进行讲解。

2.1 安装Bazel

安装Bazel没有比官方的安装教程更为好的参考资料了,请移步Installing Bazel on Ubuntu
官方安装方法有三种,如图1所示,建议采用推荐的方法一。采用方法二进行安装对时候如果SB的IT对网络有限制,那么在执行“curl https://bazel.build/bazel-release.pub.gpg | sudo apt-key add -”的时候很有可能报错。

Tensorflow部署到移动端
图1 Bazel安装

2.2 编译freeze_graph、toco

在模型转换的时候需要用到一些bazel编译出来的工具,比如freeze_graph、transform_graph、toco等,本文主要讲述模型转换中必要对工具。
TensorFlow官方github下载tensorflow,解压之后所在文件夹为tensorflow-master,打开终端,进入tensorflow-master,运行以下指令:
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/contrib/lite/toco:toco
如果一切顺利,则其编译过程会像图2-5所示:

bazel build tensorflow/python/tools:freeze_graph

Tensorflow部署到移动端
图2 freeze_graph编译进行中
Tensorflow部署到移动端
图3 freeze_graph编译成功

bazel build tensorflow/contrib/lite/toco:toco

Tensorflow部署到移动端
图4 toco编译进行中
Tensorflow部署到移动端
图5 toco编译成功

往往事与愿违,bazel编译工具虽然没有坑,但是往往由于某些原因导致编译失败,排错十四天之后,你会发现是因为bazel编译的时候会到https://bitbucket.org去下载一些文件(详见workspace.bzl),而该网站是没有给授权的。折腾很长时间发现是被IT坑了,心中难免不会有千千万万的*在奔腾。
bazel编译不成功,由于网站授权报的错如图6所示:

Tensorflow部署到移动端
图6 网站授权导致bazel编译不成功

2.3 toco_landscape

TensorFlow有多种高级API,其生成模型多样,但都可使用toco工具进行转换为.tflite文件,toco_landscape如图7所示:

Tensorflow部署到移动端
图7 toco_landscape

2.4 eval.pbtxt、.pb、.tflite

训练的时候会生成graph.pbtxt,但是万不可直接用训练的graph.pbtxt进行模型的freeze(企图使用训练的graph.pbtxt将模型转换为.pb格式再由.pb格式转换为.tflite格式的后果详见生成.tflite文件过程中遇到的问题及解决方案),以及后续的transform、toco都应当用验证的eval.pbtxt。下面以mobilenet_v1为例讲述eval.pbtxt的生成方法(更详细请参考tf_export_inference_graph.py):

import tensorflow as tf
slim = tf.contrib.slim
# Can be any nets you want to export
from nets import mobilenet_v1

def export_eval_pbtxt():
  """Export eval.pbtxt."""
  g = tf.Graph()
  with g.as_default():
    inputs = tf.placeholder(dtype=tf.float32,shape=[None,224,224,3])
    scope = mobilenet_v1.mobilenet_v1_arg_scope(
        is_training=False, weight_decay=0.0)
    with slim.arg_scope(scope):
      _, _ = mobilenet_v1.mobilenet_v1(
          inputs,
          is_training=False,
          depth_multiplier=FLAGS.depth_multiplier,
          num_classes=FLAGS.num_classes)
    eval_graph_file = '/home/lg/Desktop/mobilenet_v1_eval.pbtxt'
    with tf.Session() as sess:
          with open(eval_graph_file, 'w') as f:
            f.write(str(g.as_graph_def()))

需要注意的是export_eval_pbtxt()函数中inputs = tf.placeholder(dtype=tf.float32,shape=[None,224,224,3]),placeholder的数据类型应该是float,否则在对单幅图像进行预测时使用tf.gfile.FastFile读取图片、解码之后送入feed_dict时会报如下图8的错误:

Tensorflow部署到移动端
图8 注意placeholder数据类型

然后,frozen the graph:

bazel-bin/tensorflow/python/tools/freeze_graph  \
--input_graph=/home/lg/Desktop/mobilenet_v1_eval.pbtxt \
--input_checkpoint=/home/lg/Desktop/checkpoint/model.ckpt-10000 \
--input_binary=false \
--output_graph=/home/lg/Desktop/frozen_mobilenet_v1_224.pb  \
--output_node_names=MobilenetV1/Predictions/Reshape_1  \
--checkpoint_version=2

最后,.pb转换为.tflite,可以选择保持FLOAT格式,或进行量化:
保持FLOAT:

bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/frozen_mobilenet_v1_224.pb \
--input_format=TENSORFLOW_GRAPHDEF  \
--output_format=TFLITE  \
--output_file=/home/lg/Desktop/frozen_graph_mobilenet_v1-FLOAT.tflite \
--inference_type=FLOAT  \
--input_type=FLOAT \
--input_arrays=Placeholder  \
--output_arrays=MobilenetV1/Predictions/Reshape_1  \
--input_shapes=1,224,224,3

量化QUANTIZED_UINT8:

bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/frozen_mobilenet_v1_224.pb \
--input_format=TENSORFLOW_GRAPHDEF  \
--output_format=TFLITE  \
--output_file=/home/lg/Desktop/frozen_graph_mobilenet_v1-QUANTIZED_UINT8.tflite \
--inference_type=QUANTIZED_UINT8  \
--input_type=QUANTIZED_UINT8 \
--input_arrays=Placeholder  \
--output_arrays=MobilenetV1/Predictions/Reshape_1  \
--input_shapes=1,224,224,3 \
--default_ranges_min=0.0 \
--default_ranges_max=255.0

需要注意的是该处的量化是对数据类型的转换,和训练时的量化是两码事。训练时的量化可参考tf_train.py中:

if FLAGS.quantize:
  tf.contrib.quantize.create_training_graph(quant_delay=get_quant_delay())

训练产生的模型文件如图9所示:

Tensorflow部署到移动端
图9 训练产生的模型文件

转换过程中的文件如图10所示:

Tensorflow部署到移动端
图10 转换的模型文件

由图10可看出,量化之后.tflite的大小约是FLOAT的14

2.5 确定.pb、.tflite中input、output的名称

最简单的方式是用tensorboard打开生成的eval.pbtxt文件,可视化之后很容易确定freeze_graph时参数output_node_names的值,以及toco时input_arrays、output_arrays的值。
对于mobilenet_v1,按照前文所述产生eval.pbtxt的方法,找到相应参数名称如图11、12所示:

Tensorflow部署到移动端
图11 mobilenet_v1 input
Tensorflow部署到移动端
图12 mobilenet_v1 output

对于inception_v3 ,按照前文所述产生eval.pbtxt的方法,找到相应参数名称如图13、14所示:

Tensorflow部署到移动端
图13 inception_v3 input
Tensorflow部署到移动端
图14 inception_v3 output

在直接使用train.pbtxt进行freeze_graph、toco的时候找到的input、output相应名称如图15、16所示,可见虽然output是正确的,但是input是错误的。

Tensorflow部署到移动端
图15 mobilenet_v1 错误的使用了train的graph.pbtxt
Tensorflow部署到移动端
图16 inception_v3错误的使用了train的graph.pbtxt

2.5 部署到Android进进手机

该部分可参考之前的文章Tensorflow Lite Android Demo App,需要做的只是把标签文件labels,模型文件.tflite替换掉原工程中对应对文件即可。

3 参考文献

TensorFlow Lite
Preparing models for mobile deployment
TensorFlow Lite & TensorFlow Compatibility Guide
Graph Transform Tool

Tensorflow部署到移动端

更多资料请移步github:
https://github.com/GarryLau/tensorflowClassification