TensorFlow Lite Development Handbook (5): Using a TensorFlow Lite Model (Classification Model)
(1) Create a new CLion project
Download the project from (https://download.csdn.net/download/weixin_42499236/11892106) and extract it.
(2) Write CMakeLists.txt
The configuration below assumes the TensorFlow Lite static library has already been built with the Makefile-based build (for example via tensorflow/lite/tools/make/build_lib.sh), which produces gen/linux_x86_64/lib/libtensorflow-lite.a:
cmake_minimum_required(VERSION 3.15)
project(testlite)
set(CMAKE_CXX_STANDARD 14)
include_directories(/home/ai/CLionProjects/tensorflow-master/)
include_directories(/home/ai/CLionProjects/tensorflow-master/tensorflow/lite/tools/make/downloads/flatbuffers/include)
include_directories(/home/ai/CLionProjects/tensorflow-master/tensorflow/lite/tools/make/downloads/absl)
add_executable(testlite main.cpp bitmap_helpers.cc utils.cc)
target_link_libraries(testlite /home/ai/CLionProjects/tensorflow-master/tensorflow/lite/tools/make/gen/linux_x86_64/lib/libtensorflow-lite.a -lpthread -ldl -lrt)
(3) Write main.cpp
- Include headers
#include <fcntl.h> // NOLINT(build/include_order)
#include <getopt.h> // NOLINT(build/include_order)
#include <sys/time.h> // NOLINT(build/include_order)
#include <sys/types.h> // NOLINT(build/include_order)
#include <sys/uio.h> // NOLINT(build/include_order)
#include <unistd.h> // NOLINT(build/include_order)
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
#include "bitmap_helpers.h"
#include "get_top_n.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/optional_debug_tools.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "absl/memory/memory.h"
#include "utils.h"
using namespace std;
- Enable GPU / NNAPI acceleration (falls back to the CPU when no GPU is available)
#define LOG(x) std::cerr
double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }
using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
using TfLiteDelegatePtrMap = std::map<std::string, TfLiteDelegatePtr>;
// Create the GPU delegate
TfLiteDelegatePtr CreateGPUDelegate(tflite::label_image::Settings* s) {
#if defined(__ANDROID__)
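// Note: the Android path also needs tensorflow/lite/delegates/gpu/delegate.h for TfLiteGpuDelegateOptionsV2.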
TfLiteGpuDelegateOptionsV2 gpu_opts = TfLiteGpuDelegateOptionsV2Default();
gpu_opts.inference_preference =
TFLITE_GPU_INFERENCE_PREFERENCE_SUSTAINED_SPEED;
gpu_opts.is_precision_loss_allowed = s->allow_fp16 ? 1 : 0;
return tflite::evaluation::CreateGPUDelegate(s->model, &gpu_opts);
#else
return tflite::evaluation::CreateGPUDelegate(s->model);
#endif
}
TfLiteDelegatePtrMap GetDelegates(tflite::label_image::Settings* s) {
TfLiteDelegatePtrMap delegates;
if (s->gl_backend) {
auto delegate = CreateGPUDelegate(s);
if (!delegate) {
LOG(INFO) << "GPU acceleration is unsupported on this platform.";
} else {
delegates.emplace("GPU", std::move(delegate));
}
}
if (s->accel) {
auto delegate = tflite::evaluation::CreateNNAPIDelegate();
if (!delegate) {
LOG(INFO) << "NNAPI acceleration is unsupported on this platform.";
} else {
delegates.emplace("NNAPI", tflite::evaluation::CreateNNAPIDelegate());
}
}
return delegates;
}
- Read the labels file
TfLiteStatus ReadLabelsFile(const string& file_name,
std::vector<string>* result,
size_t* found_label_count) {
std::ifstream file(file_name);
if (!file) {
LOG(FATAL) << "Labels file " << file_name << " not found\n";
return kTfLiteError;
}
result->clear();
string line;
while (std::getline(file, line)) {
result->push_back(line);
}
*found_label_count = result->size();
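// Pad with empty strings so the number of labels is a multiple of 16,
// since the model's output tensor may be padded to that length.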
const int padding = 16;
while (result->size() % padding) {
result->emplace_back();
}
return kTfLiteOk;
}
- Print per-node profiling info
void PrintProfilingInfo(const tflite::profiling::ProfileEvent* e,
uint32_t subgraph_index, uint32_t op_index,
TfLiteRegistration registration) {
// output something like
// time (ms) , Node xxx, OpCode xxx, symbolic name
// 5.352, Node 5, OpCode 4, DEPTHWISE_CONV_2D
LOG(INFO) << std::fixed << std::setw(10) << std::setprecision(3)
<< (e->end_timestamp_us - e->begin_timestamp_us) / 1000.0
<< ", Subgraph " << std::setw(3) << std::setprecision(3)
<< subgraph_index << ", Node " << std::setw(3)
<< std::setprecision(3) << op_index << ", OpCode " << std::setw(3)
<< std::setprecision(3) << registration.builtin_code << ", "
<< EnumNameBuiltinOperator(
static_cast<tflite::BuiltinOperator>(registration.builtin_code))
<< "\n";
}
- Define the inference function
void RunInference(tflite::label_image::Settings* s){
if (s->model_name.empty()) {
LOG(ERROR) << "no model file name\n";
exit(-1);
}
// Load the .tflite model
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str());
if (!model) {
LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n";
exit(-1);
}
s->model = model.get();
LOG(INFO) << "Loaded model " << s->model_name << "\n";
model->error_reporter();
LOG(INFO) << "resolved reporter\n";
// Build the interpreter
tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
if (!interpreter) {
LOG(FATAL) << "Failed to construct interpreter\n";
exit(-1);
}
interpreter->UseNNAPI(s->old_accel);
interpreter->SetAllowFp16PrecisionForFp32(s->allow_fp16);
// Print interpreter details, including tensor sizes and input node names
if (s->verbose) {
LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n";
LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n";
LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n";
LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n";
int t_size = interpreter->tensors_size();
for (int i = 0; i < t_size; i++) {
if (interpreter->tensor(i)->name)
LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
<< interpreter->tensor(i)->bytes << ", "
<< interpreter->tensor(i)->type << ", "
<< interpreter->tensor(i)->params.scale << ", "
<< interpreter->tensor(i)->params.zero_point << "\n";
}
}
if (s->number_of_threads != -1) {
interpreter->SetNumThreads(s->number_of_threads);
}
// Default input image parameters
int image_width = 224;
int image_height = 224;
int image_channels = 3;
// Read the BMP image
std::vector<uint8_t> in = tflite::label_image::read_bmp(s->input_bmp_name, &image_width,
&image_height, &image_channels, s);
int input = interpreter->inputs()[0];
if (s->verbose) LOG(INFO) << "input: " << input << "\n";
const std::vector<int> inputs = interpreter->inputs();
const std::vector<int> outputs = interpreter->outputs();
if (s->verbose) {
LOG(INFO) << "number of inputs: " << inputs.size() << "\n";
LOG(INFO) << "number of outputs: " << outputs.size() << "\n";
}
// Apply the GPU/NNAPI delegates to the graph
auto delegates_ = GetDelegates(s);
for (const auto& delegate : delegates_) {
if (interpreter->ModifyGraphWithDelegate(delegate.second.get()) !=
kTfLiteOk) {
LOG(FATAL) << "Failed to apply " << delegate.first << " delegate.";
} else {
LOG(INFO) << "Applied " << delegate.first << " delegate.";
}
}
if (interpreter->AllocateTensors() != kTfLiteOk) {
LOG(FATAL) << "Failed to allocate tensors!";
}
if (s->verbose) PrintInterpreterState(interpreter.get());
// Get the input tensor's dimensions from its metadata
TfLiteIntArray* dims = interpreter->tensor(input)->dims;
int wanted_height = dims->data[1];
int wanted_width = dims->data[2];
int wanted_channels = dims->data[3];
// Resize the image to match the input tensor
switch (interpreter->tensor(input)->type) {
case kTfLiteFloat32:
s->input_floating = true;
tflite::label_image::resize<float>(interpreter->typed_tensor<float>(input), in.data(),
image_height, image_width, image_channels, wanted_height,
wanted_width, wanted_channels, s);
break;
case kTfLiteUInt8:
tflite::label_image::resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in.data(),
image_height, image_width, image_channels, wanted_height,
wanted_width, wanted_channels, s);
break;
default:
LOG(FATAL) << "cannot handle input type "
<< interpreter->tensor(input)->type << " yet";
exit(-1);
}
// Attach a profiler to the interpreter
auto profiler =
absl::make_unique<tflite::profiling::Profiler>(s->max_profiling_buffer_entries);
interpreter->SetProfiler(profiler.get());
if (s->profiling) profiler->StartProfiling();
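// Warm-up invocations (only when timing multiple runs) so the timed loop
// below measures steady-state performance.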
if (s->loop_count > 1)
for (int i = 0; i < s->number_of_warmup_runs; i++) {
if (interpreter->Invoke() != kTfLiteOk) {
LOG(FATAL) << "Failed to invoke tflite!\n";
}
}
// Run inference and measure the elapsed time
struct timeval start_time, stop_time;
gettimeofday(&start_time, nullptr);
for (int i = 0; i < s->loop_count; i++) {
if (interpreter->Invoke() != kTfLiteOk) {
LOG(FATAL) << "Failed to invoke tflite!\n";
}
}
gettimeofday(&stop_time, nullptr);
LOG(INFO) << "invoked \n";
LOG(INFO) << "average time: "
<< (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
<< " ms \n";
// Print the profiling events
if (s->profiling) {
profiler->StopProfiling();
auto profile_events = profiler->GetProfileEvents();
for (int i = 0; i < profile_events.size(); i++) {
auto subgraph_index = profile_events[i]->event_subgraph_index;
auto op_index = profile_events[i]->event_metadata;
const auto subgraph = interpreter->subgraph(subgraph_index);
const auto node_and_registration =
subgraph->node_and_registration(op_index);
const TfLiteRegistration registration = node_and_registration->second;
PrintProfilingInfo(profile_events[i], subgraph_index, op_index,
registration);
}
}
const float threshold = 0.001f;
std::vector<std::pair<float, int>> top_results;
// Get the top-N results
int output = interpreter->outputs()[0];
TfLiteIntArray* output_dims = interpreter->tensor(output)->dims;
// assume output dims to be something like (1, 1, ... ,size)
auto output_size = output_dims->data[output_dims->size - 1];
switch (interpreter->tensor(output)->type) {
case kTfLiteFloat32:
tflite::label_image::get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
s->number_of_results, threshold, &top_results, true);
break;
case kTfLiteUInt8:
tflite::label_image::get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
output_size, s->number_of_results, threshold,
&top_results, false);
break;
default:
LOG(FATAL) << "cannot handle output type "
<< interpreter->tensor(input)->type << " yet";
exit(-1);
}
std::vector<string> labels;
size_t label_count;
if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk)
exit(-1);
// Print the top-N results
for (const auto& result : top_results) {
const float confidence = result.first;
const int index = result.second;
LOG(INFO) << confidence << ": " << index << " " << labels[index] << "\n";
}
}
int main() {
tflite::label_image::Settings s;
RunInference(&s);
}
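main() above simply runs with the defaults baked into Settings. The upstream label_image example parses command-line flags with getopt_long (which is why getopt.h is included); if you do not need the full flag parser, a minimal sketch that overrides a few fields from positional arguments might look like this (illustrative only, not the upstream parser):
// Minimal sketch: override a few Settings fields from positional arguments.
int main(int argc, char** argv) {
  tflite::label_image::Settings s;
  if (argc > 1) s.model_name = argv[1];        // path to the .tflite model
  if (argc > 2) s.input_bmp_name = argv[2];    // path to the input .bmp image
  if (argc > 3) s.labels_file_name = argv[3];  // path to labels.txt
  RunInference(&s);
  return 0;
}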
(4) Download the pretrained model
# Get model
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz | tar xzv -C /tmp
# Get labels
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz | tar xzv -C /tmp mobilenet_v1_1.0_224/labels.txt
mv /tmp/mobilenet_v1_1.0_224/labels.txt /tmp/
(5) Modify the model configuration
Modify Settings in label_image.h (the test image grace_hopper.bmp referenced below ships with the TensorFlow source under tensorflow/lite/examples/label_image/testdata/):
struct Settings {
bool verbose = false;
bool accel = false;
bool old_accel = false;
bool input_floating = false;
bool profiling = false;
bool allow_fp16 = false;
bool gl_backend = false;
int loop_count = 1;
float input_mean = 127.5f;
float input_std = 127.5f;
string model_name = "/home/ai/CLionProjects/tflite/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.tflite";
tflite::FlatBufferModel* model;
string input_bmp_name = "/home/ai/CLionProjects/tflite/grace_hopper.bmp";
string labels_file_name = "/home/ai/CLionProjects/tflite/mobilenet_v1_1.0_224/labels.txt";
string input_layer_type = "uint8_t";
int number_of_threads = 4;
int number_of_results = 5;
int max_profiling_buffer_entries = 1024;
int number_of_warmup_runs = 2;
};
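The input_mean and input_std fields only affect float (non-quantized) models: the float path in the bitmap_helpers resize routine normalizes each pixel as (value - input_mean) / input_std, while the uint8 path copies pixels through unchanged. A minimal sketch of that per-pixel transform (illustrative, not the actual helper):
#include <cstdint>
// Illustrative only: the per-pixel transform the float input path applies
// using input_mean and input_std. With the defaults of 127.5f, a byte in
// [0, 255] maps to roughly [-1, 1].
inline float NormalizePixel(uint8_t value, float mean, float std_dev) {
  return (static_cast<float>(value) - mean) / std_dev;
}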
(6) Run the example
The top-5 classification output is as follows:
Loaded model /tmp/mobilenet_v1_1.0_224.tflite
resolved reporter
invoked
average time: 68.12 ms
0.860174: 653 653:military uniform
0.0481017: 907 907:Windsor tie
0.00786704: 466 466:bulletproof vest
0.00644932: 514 514:cornet, horn, trumpet, trump
0.00608029: 543 543:drumstick
The output shows the image is classified correctly, with an average inference time of 68.12 ms, which is quite fast.