Tensorflow 模型转 tflite ,在安卓端使用
自己在将tensorflow模型移动端部署的时候(使用 tensorflow lite),踩了很多坑,查了很多资料,现在做个记录,所有参考资料在文章最后 参考 处列出。
tensorflow lite是TensorFlow Lite 是 Google I/O 2017 大会上的其中一个重要宣布,有了TensorFlow Lite,应用开发者可以在移动设备上部署人工智能。
tensorflow lite 【github】
基本思路:
- 在pc端进行 Tensorflow 模型训练,保存训练模型
- 使用 工具将该模型转换为 Tensorflow lite 模型
- 在Android上使用
tensorflow模型持久化
在tensorflow中进行模型训练,得到适合自己项目的模型。Tensorflow 模型训练好之后会生成三个文件:
- model.ckpt.meta :保存Tensorflow计算图结构,可以理解为神经网络的网络结构
- model.ckpt :保存Tensorflow程序中每一个变量的取值,变量是模型中可训练的部分
- checkpoint :保存一个目录下所有模型文件列表
# 使用tf.train.write_graph导出GraphDef文件
tf.train.write_graph(sess.graph_def, "./", "mz_graph.pb", as_text=False)
# 使用tf.train.save导出checkpoint文件
saver.save(sess, model_path)
生成的模型文件如下图所示:
bazel编译需要的工具
Tensoflow使用的编译工具是 bazel,谷歌开源的自动化构建工具。【bazel传送门】
安装bazel,用来编译 tensorflow 转 tflite 时用到的几个工具,freeze、toco、summarize_graph(具体作用下面说),这些工具都在 tensorflow(从github上clone) 中,按下面命令进行编译(在 tensorflow目录下进行):
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/contrib/lite/toto:toto
Bazel build tensorflow/tools/graph_transforms:summarize_graph (查看模型结构,找出输入输出)
模型转换
将训练好的tf模型,进行freeze、toco操作,freeze主要是将 tensorflow模型持久化 中生成的文件进行合并,得到一个变量值和运算图模型相结合的文件,是将变量值固定在图中的操作。如上图,这步生成 mz_freezegraph.pb .
summarize_graph
该命令查看整个Tensorflow模型概况,使用命令如下,运行之后,得到自己整个网络结构,从中可以找到自己模型的输入输出,如下图(模型比较乱。。。)
# “--in_graph=” 后面是模型存储的位置
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=../mz_graph.pb
freeze_graph
该命令是 Tensorflow模型固化,将Tensorflow模型和计算图上变量的值合二为一,方便直接转换 Tensorflow lite 模型。
bazel-bin/tensorflow/python/tools/freeze_graph\
--input_graph=/tmp/mobilenet_v1_224.pb \
--input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \
--input_binary=true \
--output_graph=/tmp/frozen_mobilenet_v1_224.pb \
--output_node_names=MobileNet/Predictions/Reshape_1
- input_graph :Tensorflow 模型结构文件
- input_checkpoint :Tensorflow 模型 ckpt 文件
- output_graph :输出的freeze文件
- output_node_names :模型输出节点名字,使用 summarize_graph 查看 ,可以在 Tensorflow 网络训练时进行命名
toco
固化模型到 tflite 模型转化
toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/tmp/mobilenet_v1_1.0_224.tflite \
--inference_type=FLOAT \
--input_type=FLOAT \
--input_arrays=input \
--output_arrays=MobilenetV1/Predictions/Reshape_1 \
--input_shapes=1,224,224,3
- input_file : freeze 之后的 Tensorflow 模型文件
- output_file :转换好的 Tensorflow lite 模型,扩展名为 .tflite
- output_arrays :仍然是Tensorflow 模型的输出
- input_shapes :输入图片的维度
部署Android
1、安装 官方GitHub进行Android软件搭建 Tensorflow lite 【Github】
2、工程中有 Float 和 Quantized 两个模式可选,如下图,这里使用Float,Quantized需要先量化模型,在进行 tflite 模型转换。
3、将生成的 .tflite 文件和 对应的 labels.txt 文件放入Android工程的 assets 文件中。
4、运行即可。
参考
上一篇: Java解析json