欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Tensorflow 模型转 tflite ,在安卓端使用

程序员文章站 2024-03-24 15:22:16
...

自己在将tensorflow模型移动端部署的时候(使用 tensorflow lite),踩了很多坑,查了很多资料,现在做个记录,所有参考资料在文章最后 参考 处列出。

tensorflow lite是TensorFlow Lite 是 Google I/O 2017 大会上的其中一个重要宣布,有了TensorFlow Lite,应用开发者可以在移动设备上部署人工智能。
tensorflow lite 【github】

Tensorflow 模型转 tflite ,在安卓端使用

基本思路:

  1. 在pc端进行 Tensorflow 模型训练,保存训练模型
  2. 使用 工具将该模型转换为 Tensorflow lite 模型
  3. 在Android上使用

tensorflow模型持久化

在tensorflow中进行模型训练,得到适合自己项目的模型。Tensorflow 模型训练好之后会生成三个文件:

  • model.ckpt.meta :保存Tensorflow计算图结构,可以理解为神经网络的网络结构
  • model.ckpt :保存Tensorflow程序中每一个变量的取值,变量是模型中可训练的部分
  • checkpoint :保存一个目录下所有模型文件列表
# 使用tf.train.write_graph导出GraphDef文件
tf.train.write_graph(sess.graph_def, "./", "mz_graph.pb", as_text=False)
# 使用tf.train.save导出checkpoint文件
saver.save(sess, model_path)

生成的模型文件如下图所示:
Tensorflow 模型转 tflite ,在安卓端使用

bazel编译需要的工具

Tensoflow使用的编译工具是 bazel,谷歌开源的自动化构建工具。【bazel传送门】
安装bazel,用来编译 tensorflow 转 tflite 时用到的几个工具,freeze、toco、summarize_graph(具体作用下面说),这些工具都在 tensorflow(从github上clone) 中,按下面命令进行编译(在 tensorflow目录下进行):

bazel build tensorflow/python/tools:freeze_graph

bazel build tensorflow/contrib/lite/toto:toto

Bazel build tensorflow/tools/graph_transforms:summarize_graph  (查看模型结构,找出输入输出)

模型转换

将训练好的tf模型,进行freeze、toco操作,freeze主要是将 tensorflow模型持久化 中生成的文件进行合并,得到一个变量值和运算图模型相结合的文件,是将变量值固定在图中的操作。如上图,这步生成 mz_freezegraph.pb .

summarize_graph

该命令查看整个Tensorflow模型概况,使用命令如下,运行之后,得到自己整个网络结构,从中可以找到自己模型的输入输出,如下图(模型比较乱。。。)

# --in_graph=” 后面是模型存储的位置
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=../mz_graph.pb

Tensorflow 模型转 tflite ,在安卓端使用

freeze_graph

该命令是 Tensorflow模型固化,将Tensorflow模型和计算图上变量的值合二为一,方便直接转换 Tensorflow lite 模型。

    bazel-bin/tensorflow/python/tools/freeze_graph\
        --input_graph=/tmp/mobilenet_v1_224.pb \
        --input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \
        --input_binary=true \
        --output_graph=/tmp/frozen_mobilenet_v1_224.pb \
        --output_node_names=MobileNet/Predictions/Reshape_1
  • input_graph :Tensorflow 模型结构文件
  • input_checkpoint :Tensorflow 模型 ckpt 文件
  • output_graph :输出的freeze文件
  • output_node_names :模型输出节点名字,使用 summarize_graph 查看 ,可以在 Tensorflow 网络训练时进行命名

Tensorflow 模型转 tflite ,在安卓端使用

toco

固化模型到 tflite 模型转化

toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
      --input_format=TENSORFLOW_GRAPHDEF \
      --output_format=TFLITE \
      --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
      --inference_type=FLOAT \
      --input_type=FLOAT \
      --input_arrays=input \
      --output_arrays=MobilenetV1/Predictions/Reshape_1 \
      --input_shapes=1,224,224,3
  • input_file : freeze 之后的 Tensorflow 模型文件
  • output_file :转换好的 Tensorflow lite 模型,扩展名为 .tflite
  • output_arrays :仍然是Tensorflow 模型的输出
  • input_shapes :输入图片的维度

Tensorflow 模型转 tflite ,在安卓端使用

部署Android

1、安装 官方GitHub进行Android软件搭建 Tensorflow lite 【Github】
2、工程中有 FloatQuantized 两个模式可选,如下图,这里使用Float,Quantized需要先量化模型,在进行 tflite 模型转换。
3、将生成的 .tflite 文件和 对应的 labels.txt 文件放入Android工程的 assets 文件中。
4、运行即可。

Tensorflow 模型转 tflite ,在安卓端使用

参考

  1. TensorFlow Lite学习笔记2:生成TFLite模型文件
  2. TensorFlow固化模型
  3. TensorFlow Lite模型生成以及bazel的安装使用、出现的问题及解决方案整合
  4. Tensorflow Lite之编译生成tflite文件
  5. tensorflow Lite的使用
  6. tensorflow模型量化
  7. 用 TensorFlow 压缩神经网络
  8. 在Android上使用TensorFlow Lite