Tensorflow Lite初探(Android)
一、背景:
11月15日,谷歌正式发布了TensorFlow Lite开发者预览版。
TensorFlow Lite 是 Google I/O 2017 大会上的其中一个重要宣布,有了TensorFlow Lite,应用开发者可以在移动设备上部署人工智能。
Google 表示 Lite 版本 TensorFlow 是 TensorFlow Mobile 的一个延伸版本。尽管是一个轻量级版本,依然是在智能手机和嵌入式设备上部署深度学习的一大动作。此前,通过TensorFlow Mobile API,TensorFlow已经支持手机上的模型嵌入式部署。TensorFlow Lite应该被视为TensorFlow Mobile的升级版。
TensorFlow Lite 目前仍处于“积极开发”状态,目前仅有少量预训练AI模型面世,比如MobileNet、用于计算机视觉物体识别的Inception v3、用于自然语言处理的Smart Reply,当然,TensorFlow Lite上也可以部署用自己的数据集定制化训练的模型。
TensorFlow Lite可以与Android 8.1中发布的神经网络API完美配合,即便在没有硬件加速时也能调用CPU处理,确保模型在不同设备上的运行。 而Android端版本演进的控制权是掌握在谷歌手中的,从长期看,TensorFlow Lite会得到Android系统层面上的支持。
其组件包括:
- TensorFlow 模型(TensorFlow Model):保存在磁盘中的训练模型。
- TensorFlow Lite 转化器(TensorFlow Lite Converter):将模型转换成 TensorFlow Lite 文件格式的项目。
- TensorFlow Lite 模型文件(TensorFlow Lite Model File):基于 FlatBuffers,适配最大速度和最小规模的模型。
github链接如下:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite
二、环境:
Android Studio 3.0, SDK Version API26, NDK Version 14
步骤:
1. 将此项目导入到Android Studio:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo
2. 下载移动端的模型(model)和标签数据(lables):
https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
3. 下载完成解压mobilenet_v1_224_android_quant_2017_11_08.zip文件得到一个xxx.tflite和labes.txt文件,分别是模型和标签文件,并且把这两个文件复制到assets文件夹下。
4. 构建app,run……
详情可参考:http://blog.csdn.net/wu__di/article/details/78570303
三、源码分析:
整个demo的代码非常少,仅包含4个java文件(相信随着正式版的发布,会有更加丰富的功能以及更多的预训练模型):
其中:
- AutoFitTextureView: 一个自定义View;
- CameraActivity: 整个app的入口activity,这个activity只做了一件事,就是加载了一个fragment;
- Camera2BasicFragment: 入口activity中加载的fragment,其中实现了所有跟UI相关的代码;首先在onActivityCreated中,初始化了一个ImageClassifier对象,此类是整个demo的核心,用于加载模型并实现推理运算功能。然后开启了一个后台线程,在线程中反复地对从摄像头获取的图像进行分类操作。
/** Load the model and labels. */
@Override
public void onActivityCreated(Bundle savedInstanceState) {
super.onActivityCreated(savedInstanceState);
try {
classifier = new ImageClassifier(getActivity());
} catch (IOException e) {
Log.e(TAG, "Failed to initialize an image classifier.");
}
startBackgroundThread();
}
startBackgroundThread()中做的轮询操作:
private Runnable periodicClassify =
new Runnable() {
@Override
public void run() {
synchronized (lock) {
if (runClassifier) {
classifyFrame();
}
}
backgroundHandler.post(periodicClassify);
}
};
其中,classifyFrame()代码如下:
/** Classifies a frame from the preview stream. */
private void classifyFrame() {
if (classifier == null || getActivity() == null || cameraDevice == null) {
showToast("Uninitialized Classifier or invalid context.");
return;
}
Bitmap bitmap =
textureView.getBitmap(ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y);
String textToShow = classifier.classifyFrame(bitmap);
bitmap.recycle();
showToast(textToShow);
}
大致过程就是从控件textureView中以指定的长宽读取一个Bitmap出来(也就是摄像头的实时画面),然后交给classifier的classifyFrame进行处理,返回一个结果,这个结果就是图片分类的结果,然后显示在手机屏幕上。
ImageClassifier:demo最重要的部分,但只有两个函数比较重要,一个是构造函数:
/** Initializes an {@code ImageClassifier}. */
ImageClassifier(Activity activity) throws IOException {
tflite = new Interpreter(loadModelFile(activity));
labelList = loadLabelList(activity);
imgData =
ByteBuffer.allocateDirect(
DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
imgData.order(ByteOrder.nativeOrder());
labelProbArray = new byte[1][labelList.size()];
Log.d(TAG, "Created a Tensorflow Lite Image Classifier.");
}
其中Interpreter类非常关键,这是Android app与tensorflow lite之间的桥梁,位于org.tensorflow:tensorflow-lite-0.1.1中:
这个包实现了对张量(tensor)的基本操作,而整个tensorflow就是以张量为单位处理各种运算。
tflite = new Interpreter(loadModelFile(activity))这里通过loadModelFile将asset中的tflite格式的模型文件加载并返回一个MappedByteBuffer传给Interpreter。labelList = loadLabelList(activity)将asset中的labels文件中的分类标签加载到字符串列表labelList中。imgData则是一个存放输入张量的buffer,一个非常典型的(batch_size, x, y, channel)结构,在这里可以理解为一个placeholder。
最后labelProbArray是一个1 x labelList.size()的张量,可以认为是一个向量,元素的个数就是模型输出结果的总类别数,每一个元素代表模型判断到图片为某一类别的概率,对应于labels。
另一个是实现图片分类的函数:
/** Classifies a frame from the preview stream. */
String classifyFrame(Bitmap bitmap) {
if (tflite == null) {
Log.e(TAG, "Image classifier has not been initialized; Skipped.");
return "Uninitialized Classifier.";
}
convertBitmapToByteBuffer(bitmap);
// Here's where the magic happens!!!
long startTime = SystemClock.uptimeMillis();
tflite.run(imgData, labelProbArray);
long endTime = SystemClock.uptimeMillis();
Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));
String textToShow = printTopKLabels();
textToShow = Long.toString(endTime - startTime) + "ms" + textToShow;
return textToShow;
}
首先convertBitmapToByteBuffer将bitmap中的像素值读出,并放入刚才初始化的imgData中,这里相当于为placeholder填充了数据。然后是最关键的一行tflite.run(imgData, labelProbArray),喂数据,得出结果,分类的结果存入labelProbArray中。
#对于这行代码,有没有似曾相识的感觉:
tf.Session().run(output, feed_dict={x:input})
最后labelProbArray转换为需要显示的文字,传给UI层。
四、关于tflite模型
关于tflite,官方有比较详细的说明:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite
这里总结一下,生成tflite有两种方式,一种是直接在模型设计流程中,通过tflite提供的接口tf.contrib.lite.toco_convert将推理图转化为可供移动端直接使用的tflite文件(由于目前是预览版,这个接口在正式版的tensorflow中还无法使用):
import tensorflow as tf
img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
with tf.Session() as sess:
tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
open("converteds_model.tflite", "wb").write(tflite_model)
还有就是将已经训练好的模型文件,转化为tflite格式。由于涉及到模型文件,这里先科普一下tensorflow的模型持久化。
从这里可以找到一些现成的模型:
https://github.com/tensorflow/models
随便下载一个,比如research/adv_imagenet_models当中的模型ens4_adv_inception_v3_2017_08_18.tar.gz,解压后可以得到这些文件:
这些文件保存了模型的信息,一般可通过如下代码生成:
import tensorflow as tf
...
saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess, "/model/xxxx.ckpt") #在session中将计算图和变量信息保存到ckpt文件中
虽然只指定了一个文件路径,但是这个目录下会生成3个文件,分别是xxx.ckpt.data,xxx.ckpt.meta,xxx.ckpt.index,正如上图所示。其中,xxx.ckpt.meta保存了计算图结构,xxx.ckpt.data保存了所有变量的取值,xxx.ckpt.index保存了所有变量名。有了这三个文件,就能得到模型的信息并加载到其他项目中。
还有一种文件需要介绍一下,*.pb,官方的描述是这样的:
- GraphDef (.pb) - a protobuf that represents the TensorFlow training and or computation graph. This contains operators, tensors, and variables definitions.
- FrozenGraphDef - a subclass of GraphDef that contains no variables. A GraphDef can be converted to a frozen graphdef by taking a checkpoint and a graphdef and converting every variable into a constant with the value looked up in the checkpoint.
这里可以简单理解为*.pb文件有两种情况,一种是仅保存了计算图结构,不包含变量值,可以通过如下代码生成:
tf.train.write_graph()
还有一种就是上面提到的FrozenGraphDef ,不仅包含计算图结构,还包含了训练产生的变量值,这类*.pb可以直接被加载用于推理运算,tensorflow mobile的一个android应用demo就是很好的例子:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android
这个demo里,android应用正是通过FrozenGraphDef的*.pb文件将模型加载到app中,从而实现模型的推理功能。
那么如何使用现有的模型文件生成tflite呢?正式需要这样一个包含计算图和变量值的冻结图文件(*.pb)。
如果已经有了这个冻结图文件,根据官方文档,可以使用如下命令生成tflite:
bazel build tensorflow/contrib/lite/toco:toco
bazel-bin/tensorflow/contrib/lite/toco/toco -- \
--input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
--input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
--output_file=/tmp/mobilenet_v1_1.0_224.lite --inference_type=FLOAT \
--input_type=FLOAT --input_arrays=input \
--output_arrays=MobilenetV1/Predictions/Reshape_1 --input_shapes=1,224,224,3
如果没有冻结图,也可以根据包含变量值的ckpt和仅包含计算图结构的pb文件生成一个冻结图文件:
bazel build tensorflow/python/tools:freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph\
--input_graph=/tmp/mobilenet_v1_224.pb \
--input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \
--input_binary=true --output_graph=/tmp/frozen_mobilenet_v1_224.pb \
--output_node_names=MobileNet/Predictions/Reshape_1
最后,如果想要使用一些现成的tflite模型,可以从这里找到:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/models.md
推荐阅读
-
TensorFlow Lite简介
-
Tensorflow Lite初探(Android)
-
Tensorflow lite for 移动端安卓开发(三)——移动端测试自己的模型
-
基于Android搭建tensorflow lite,实现官网的Demo以及运行自定义tensorflow模型(二)
-
Tensorflow Lite初探
-
TensorFlow Lite入门
-
在Android上可视化TensorFlow Lite AI结果
-
Building TensorFlow on Android so Easy
-
tensorflow入门:tensor初探
-
学习Android开发之RecyclerView使用初探