
TensorFlow Lite Development Handbook (5): Using a TensorFlow Lite Model (Classification Example)


(1) Create a new CLion project

Download the project from https://download.csdn.net/download/weixin_42499236/11892106 and unzip it.

(2) Write the CMakeLists.txt

cmake_minimum_required(VERSION 3.15)
project(testlite)

set(CMAKE_CXX_STANDARD 14)

include_directories(/home/ai/CLionProjects/tensorflow-master/)
include_directories(/home/ai/CLionProjects/tensorflow-master/tensorflow/lite/tools/make/downloads/flatbuffers/include)
include_directories(/home/ai/CLionProjects/tensorflow-master/tensorflow/lite/tools/make/downloads/absl)

add_executable(testlite main.cpp bitmap_helpers.cc utils.cc)

target_link_libraries(testlite /home/ai/CLionProjects/tensorflow-master/tensorflow/lite/tools/make/gen/linux_x86_64/lib/libtensorflow-lite.a -lpthread -ldl -lrt)
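Note that target_link_libraries points at a prebuilt static library, libtensorflow-lite.a. In the TensorFlow source tree this is typically produced beforehand by running tensorflow/lite/tools/make/download_dependencies.sh followed by tensorflow/lite/tools/make/build_lib.sh; the output directory (here gen/linux_x86_64/lib/) varies with the TensorFlow version and platform, so adjust the include and link paths to match your own checkout.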

(3) Write main.cpp

  • Include the headers
#include <fcntl.h>      // NOLINT(build/include_order)
#include <getopt.h>     // NOLINT(build/include_order)
#include <sys/time.h>   // NOLINT(build/include_order)
#include <sys/types.h>  // NOLINT(build/include_order)
#include <sys/uio.h>    // NOLINT(build/include_order)
#include <unistd.h>     // NOLINT(build/include_order)

#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include "bitmap_helpers.h"
#include "get_top_n.h"

#include "tensorflow/lite/model.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/optional_debug_tools.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "absl/memory/memory.h"
#include "utils.h"

using namespace std;
  • Enable GPU / NNAPI acceleration (falls back to the CPU if no GPU is available)
#define LOG(x) std::cerr

double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }

using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
using TfLiteDelegatePtrMap = std::map<std::string, TfLiteDelegatePtr>;

// Create the GPU delegate
TfLiteDelegatePtr CreateGPUDelegate(tflite::label_image::Settings* s) {
#if defined(__ANDROID__)
    TfLiteGpuDelegateOptionsV2 gpu_opts = TfLiteGpuDelegateOptionsV2Default();
    gpu_opts.inference_preference =
        TFLITE_GPU_INFERENCE_PREFERENCE_SUSTAINED_SPEED;
    gpu_opts.is_precision_loss_allowed = s->allow_fp16 ? 1 : 0;
    return tflite::evaluation::CreateGPUDelegate(s->model, &gpu_opts);
#else
    return tflite::evaluation::CreateGPUDelegate(s->model);
#endif
}

TfLiteDelegatePtrMap GetDelegates(tflite::label_image::Settings* s) {
    TfLiteDelegatePtrMap delegates;
    if (s->gl_backend) {
        auto delegate = CreateGPUDelegate(s);
        if (!delegate) {
            LOG(INFO) << "GPU acceleration is unsupported on this platform.";
        } else {
            delegates.emplace("GPU", std::move(delegate));
        }
    }

    if (s->accel) {
        auto delegate = tflite::evaluation::CreateNNAPIDelegate();
        if (!delegate) {
            LOG(INFO) << "NNAPI acceleration is unsupported on this platform.";
        } else {
            delegates.emplace("NNAPI", tflite::evaluation::CreateNNAPIDelegate());
        }
    }
    return delegates;
}
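Note that TfLiteDelegatePtr is, in this TensorFlow version, a std::unique_ptr with a custom deleter, so it is move-only; a successfully created delegate therefore has to be handed to the map with std::move rather than copied.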
  • Read the labels file
TfLiteStatus ReadLabelsFile(const string& file_name,
                            std::vector<string>* result,
                            size_t* found_label_count) {
    std::ifstream file(file_name);
    if (!file) {
        LOG(FATAL) << "Labels file " << file_name << " not found\n";
        return kTfLiteError;
    }
    result->clear();
    string line;
    while (std::getline(file, line)) {
        result->push_back(line);
    }
    *found_label_count = result->size();
    const int padding = 16;
    while (result->size() % padding) {
        result->emplace_back();
    }
    return kTfLiteOk;
}
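The labels file holds one class name per line, and the line number corresponds to the class index in the model's output (for this MobileNet model, index 0 is "background", followed by the 1000 ImageNet classes). Per the upstream label_image example, the trailing loop pads the vector with empty strings until its size is a multiple of 16, since some models produce an output vector padded to that length.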
  • Print model node information
void PrintProfilingInfo(const tflite::profiling::ProfileEvent* e,
                        uint32_t subgraph_index, uint32_t op_index,
                        TfLiteRegistration registration) {
    // output something like
    // time (ms) , Node xxx, OpCode xxx, symbolic name
    //      5.352, Node   5, OpCode   4, DEPTHWISE_CONV_2D

    LOG(INFO) << std::fixed << std::setw(10) << std::setprecision(3)
              << (e->end_timestamp_us - e->begin_timestamp_us) / 1000.0
              << ", Subgraph " << std::setw(3) << std::setprecision(3)
              << subgraph_index << ", Node " << std::setw(3)
              << std::setprecision(3) << op_index << ", OpCode " << std::setw(3)
              << std::setprecision(3) << registration.builtin_code << ", "
              << EnumNameBuiltinOperator(
                      static_cast<tflite::BuiltinOperator>(registration.builtin_code))
              << "\n";
}
  • Define the model inference function
void RunInference(tflite::label_image::Settings* s){
    if (s->model_name.empty()) {
        LOG(ERROR) << "no model file name\n";
        exit(-1);
    }

// Load the .tflite model
    std::unique_ptr<tflite::FlatBufferModel> model;
    std::unique_ptr<tflite::Interpreter> interpreter;
    model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str());
    if (!model) {
        LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n";
        exit(-1);
    }
    s->model = model.get();
    LOG(INFO) << "Loaded model " << s->model_name << "\n";
    model->error_reporter();
    LOG(INFO) << "resolved reporter\n";
// Build the interpreter
    tflite::ops::builtin::BuiltinOpResolver resolver;

    tflite::InterpreterBuilder(*model, resolver)(&interpreter);
    if (!interpreter) {
        LOG(FATAL) << "Failed to construct interpreter\n";
        exit(-1);
    }

    interpreter->UseNNAPI(s->old_accel);
    interpreter->SetAllowFp16PrecisionForFp32(s->allow_fp16);
// Print interpreter details, including tensor sizes, input node names, etc.
    if (s->verbose) {
        LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n";
        LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n";
        LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n";
        LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n";

        int t_size = interpreter->tensors_size();
        for (int i = 0; i < t_size; i++) {
            if (interpreter->tensor(i)->name)
                LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
                          << interpreter->tensor(i)->bytes << ", "
                          << interpreter->tensor(i)->type << ", "
                          << interpreter->tensor(i)->params.scale << ", "
                          << interpreter->tensor(i)->params.zero_point << "\n";
        }
    }

    if (s->number_of_threads != -1) {
        interpreter->SetNumThreads(s->number_of_threads);
    }

// Input image parameters (read_bmp fills in the actual BMP dimensions)
    int image_width = 224;
    int image_height = 224;
    int image_channels = 3;
// Read the BMP image
    std::vector<uint8_t> in = tflite::label_image::read_bmp(s->input_bmp_name, &image_width,
                                       &image_height, &image_channels, s);

    int input = interpreter->inputs()[0];

    if (s->verbose) LOG(INFO) << "input: " << input << "\n";

    const std::vector<int> inputs = interpreter->inputs();
    const std::vector<int> outputs = interpreter->outputs();

    if (s->verbose) {
        LOG(INFO) << "number of inputs: " << inputs.size() << "\n";
        LOG(INFO) << "number of outputs: " << outputs.size() << "\n";
    }

// Apply any available delegates (GPU / NNAPI) to the graph
    auto delegates_ = GetDelegates(s);
    for (const auto& delegate : delegates_) {
        if (interpreter->ModifyGraphWithDelegate(delegate.second.get()) !=
            kTfLiteOk) {
            LOG(FATAL) << "Failed to apply " << delegate.first << " delegate.";
        } else {
            LOG(INFO) << "Applied " << delegate.first << " delegate.";
        }
    }

    if (interpreter->AllocateTensors() != kTfLiteOk) {
        LOG(FATAL) << "Failed to allocate tensors!";
    }

    if (s->verbose) PrintInterpreterState(interpreter.get());

// Get the input tensor's dimensions from its metadata
    TfLiteIntArray* dims = interpreter->tensor(input)->dims;
    int wanted_height = dims->data[1];
    int wanted_width = dims->data[2];
    int wanted_channels = dims->data[3];

// Resize the image to the input tensor's shape
    switch (interpreter->tensor(input)->type) {
        case kTfLiteFloat32:
            s->input_floating = true;
            tflite::label_image::resize<float>(interpreter->typed_tensor<float>(input), in.data(),
                          image_height, image_width, image_channels, wanted_height,
                          wanted_width, wanted_channels, s);
            break;
        case kTfLiteUInt8:
            tflite::label_image::resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in.data(),
                            image_height, image_width, image_channels, wanted_height,
                            wanted_width, wanted_channels, s);
            break;
        default:
            LOG(FATAL) << "cannot handle input type "
                       << interpreter->tensor(input)->type << " yet";
            exit(-1);
    }

// Attach a profiler and run warm-up inferences
    auto profiler =
            absl::make_unique<tflite::profiling::Profiler>(s->max_profiling_buffer_entries);
    interpreter->SetProfiler(profiler.get());

    if (s->profiling) profiler->StartProfiling();
    if (s->loop_count > 1)
        for (int i = 0; i < s->number_of_warmup_runs; i++) {
            if (interpreter->Invoke() != kTfLiteOk) {
                LOG(FATAL) << "Failed to invoke tflite!\n";
            }
        }
// Run inference and measure the average time
    struct timeval start_time, stop_time;
    gettimeofday(&start_time, nullptr);
    for (int i = 0; i < s->loop_count; i++) {
        if (interpreter->Invoke() != kTfLiteOk) {
            LOG(FATAL) << "Failed to invoke tflite!\n";
        }
    }
    gettimeofday(&stop_time, nullptr);
    LOG(INFO) << "invoked \n";
    LOG(INFO) << "average time: "
              << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
              << " ms \n";
// Print profiling events
    if (s->profiling) {
        profiler->StopProfiling();
        auto profile_events = profiler->GetProfileEvents();
        for (int i = 0; i < profile_events.size(); i++) {
            auto subgraph_index = profile_events[i]->event_subgraph_index;
            auto op_index = profile_events[i]->event_metadata;
            const auto subgraph = interpreter->subgraph(subgraph_index);
            const auto node_and_registration =
                    subgraph->node_and_registration(op_index);
            const TfLiteRegistration registration = node_and_registration->second;
            PrintProfilingInfo(profile_events[i], subgraph_index, op_index,
                               registration);
        }
    }

    const float threshold = 0.001f;

    std::vector<std::pair<float, int>> top_results;

// Get the Top-N results
    int output = interpreter->outputs()[0];
    TfLiteIntArray* output_dims = interpreter->tensor(output)->dims;
    // assume output dims to be something like (1, 1, ... ,size)
    auto output_size = output_dims->data[output_dims->size - 1];
    switch (interpreter->tensor(output)->type) {
        case kTfLiteFloat32:
            tflite::label_image::get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
                             s->number_of_results, threshold, &top_results, true);
            break;
        case kTfLiteUInt8:
            tflite::label_image::get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
                               output_size, s->number_of_results, threshold,
                               &top_results, false);
            break;
        default:
            LOG(FATAL) << "cannot handle output type "
                       << interpreter->tensor(output)->type << " yet";
            exit(-1);
    }

    std::vector<string> labels;
    size_t label_count;

    if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk)
        exit(-1);
// Print the Top-N results
    for (const auto& result : top_results) {
        const float confidence = result.first;
        const int index = result.second;
        LOG(INFO) << confidence << ": " << index << " " << labels[index] << "\n";
    }
}
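The get_top_n helper included above comes from the upstream label_image example (get_top_n.h) and is not listed in this post. Its job is to pick the N highest scores: keep a min-heap of (score, index) pairs capped at N entries, rescaling uint8 outputs into [0, 1]. A minimal sketch of that logic (the upstream implementation may differ in detail):

#include <algorithm>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// Sketch of get_top_n: keep the N best (score, index) pairs in a min-heap,
// then emit them in descending order. uint8 scores are mapped to [0, 1].
template <class T>
void get_top_n_sketch(T* prediction, int prediction_size, size_t num_results,
                      float threshold,
                      std::vector<std::pair<float, int>>* top_results,
                      bool input_floating) {
    std::priority_queue<std::pair<float, int>,
                        std::vector<std::pair<float, int>>,
                        std::greater<std::pair<float, int>>> pq;
    for (int i = 0; i < prediction_size; ++i) {
        float value = input_floating ? prediction[i] : prediction[i] / 255.0f;
        if (value < threshold) continue;        // drop low-confidence scores
        pq.push({value, i});
        if (pq.size() > num_results) pq.pop();  // evict the current smallest
    }
    // Drain the heap (ascending) and reverse into descending order.
    while (!pq.empty()) {
        top_results->push_back(pq.top());
        pq.pop();
    }
    std::reverse(top_results->begin(), top_results->end());
}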

int main() {
    tflite::label_image::Settings s;
    RunInference(&s);
}
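main() above runs entirely on the defaults defined in Settings (section (5)). If you prefer not to hard-code paths in label_image.h, the same fields can be overridden here instead; the paths in this sketch are illustrative, not taken from the original project:

int main() {
    tflite::label_image::Settings s;
    // Illustrative overrides instead of editing label_image.h; point these
    // at wherever you actually placed the model, labels and test image.
    s.model_name = "/tmp/mobilenet_v1_1.0_224.tflite";
    s.labels_file_name = "/tmp/labels.txt";
    s.input_bmp_name = "/tmp/grace_hopper.bmp";  // hypothetical location
    s.loop_count = 10;  // average the latency over 10 runs
    RunInference(&s);
    return 0;
}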

(4) Download a pretrained model

# Get model
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz | tar xzv -C /tmp

# Get labels
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz  | tar xzv -C /tmp  mobilenet_v1_1.0_224/labels.txt

mv /tmp/mobilenet_v1_1.0_224/labels.txt /tmp/
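If the first archive unpacks as above, the model ends up at /tmp/mobilenet_v1_1.0_224.tflite and the labels at /tmp/labels.txt; make sure model_name and labels_file_name in Settings (next section) point at wherever the files actually land.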

(5) Modify the model configuration

Modify Settings in label_image.h:

struct Settings {
  bool verbose = false;
  bool accel = false;
  bool old_accel = false;
  bool input_floating = false;
  bool profiling = false;
  bool allow_fp16 = false;
  bool gl_backend = false;
  int loop_count = 1;
  float input_mean = 127.5f;
  float input_std = 127.5f;
  string model_name = "/home/ai/CLionProjects/tflite/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.tflite";
  tflite::FlatBufferModel* model;
  string input_bmp_name = "/home/ai/CLionProjects/tflite/grace_hopper.bmp";
  string labels_file_name = "/home/ai/CLionProjects/tflite/mobilenet_v1_1.0_224/labels.txt";
  string input_layer_type = "uint8_t";
  int number_of_threads = 4;
  int number_of_results = 5;
  int max_profiling_buffer_entries = 1024;
  int number_of_warmup_runs = 2;
};
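input_mean and input_std only matter for float models: in the upstream bitmap_helpers, each resized pixel is normalized as (v - input_mean) / input_std, so the default 127.5 / 127.5 maps uint8 pixel values into roughly [-1, 1]. A minimal sketch of just that normalization step (the real resize<float> also resizes the image first):

#include <cstdint>

// Sketch of the normalization applied when the input tensor is float,
// mirroring the final step of resize<float> in the upstream bitmap_helpers.
void normalize_pixels(const uint8_t* in, float* out, int n,
                      float input_mean, float input_std) {
    for (int i = 0; i < n; ++i) {
        out[i] = (in[i] - input_mean) / input_std;  // 127.5/127.5 -> ~[-1, 1]
    }
}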

(6) Run the example

The Top-5 classification results are as follows:

Loaded model /tmp/mobilenet_v1_1.0_224.tflite
resolved reporter
invoked
average time: 68.12 ms
0.860174: 653 653:military uniform
0.0481017: 907 907:Windsor tie
0.00786704: 466 466:bulletproof vest
0.00644932: 514 514:cornet, horn, trumpet, trump
0.00608029: 543 543:drumstick

The result shows that the image is classified correctly, with an average inference time of 68.12 ms, which is quite fast!