MNN Interpreter and Session
The relationship between Interpreter and Session in MNN is as follows (a minimal usage sketch follows the list):
- Interpreter is the model interpreter and also the session manager: it loads the model from a file, and creates and runs sessions;
- one Interpreter corresponds to one Net and manages multiple tasks on that network;
- a Session represents an inference task, made up of an array of Pipelines plus dictionaries of input and output tensors;
- a Session is created from the Schedule::ScheduleInfo produced by Schedule::schedule;
- Sessions belonging to the same Net must not be invoked concurrently.
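The sketch below shows that relationship from the caller's side, assuming a model file named model.mnn (hypothetical); error handling is abbreviated.
#include <MNN/Interpreter.hpp>

void runOnce() {
    // One Interpreter wraps one Net loaded from a file.
    auto net = MNN::Interpreter::createFromFile("model.mnn");
    if (nullptr == net) {
        return;
    }
    // Each Session is one inference task scheduled on that Net.
    MNN::ScheduleConfig config;
    config.type      = MNN_FORWARD_CPU;
    config.numThread = 4;
    auto session = net->createSession(config);
    // Inputs and outputs are looked up by name; nullptr selects the first.
    auto input  = net->getSessionInput(session, nullptr);
    auto output = net->getSessionOutput(session, nullptr);
    net->runSession(session);
    delete net;
}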
Interpreter
An Interpreter is constructed from a model file; copying and copy construction are not allowed, as sketched below.
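In the header this is expressed by deleting the corresponding special member functions, roughly as follows (a sketch of the idiom, not a verbatim quote of the header):
class Interpreter {
public:
    // Copying and moving an Interpreter are both forbidden.
    Interpreter(const Interpreter&)            = delete;
    Interpreter(Interpreter&&)                 = delete;
    Interpreter& operator=(const Interpreter&) = delete;
    Interpreter& operator=(Interpreter&&)      = delete;
    // ...
};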
Interpreter::createFromFile
Interpreter::createFromFile first reads the file contents in chunks using FileLoader, merges the chunks, and then creates an Interpreter object via Interpreter::createFromBufferInternal.
FileLoader::valid checks whether the file handle is null.
FileLoader::read performs the chunked read.
Content holds the raw buffer together with the network, session, and tensor information built from it.
FileLoader::merge merges the chunks that were read into the given buffer (a stand-alone sketch of this read-then-merge pattern follows the code below).
if (nullptr == file) {
MNN_PRINT("NULL file for create interpreter\n");
return nullptr;
}
std::unique_ptr<FileLoader> loader(new FileLoader(file));
if (!loader->valid()) {
MNN_PRINT("Create interpreter failed, open %s error\n", file);
return nullptr;
}
bool result = loader->read();
if (!result) {
MNN_PRINT("Read file error\n");
return nullptr;
}
if (loader->size() == 0) {
MNN_PRINT("Create interpreter failed, %s is empty\n", file);
return nullptr;
}
auto net = new Content;
bool success = loader->merge(net->buffer);
if (!success) {
return nullptr;
}
loader.reset();
return createFromBufferInternal(net);
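FileLoader is internal to MNN, so here is a self-contained illustration of the same idea: read a file in fixed-size blocks, then merge the blocks into one contiguous buffer. readFileByChunks is a hypothetical helper, not an MNN API.
#include <cstdio>
#include <vector>

static std::vector<char> readFileByChunks(const char* path, size_t chunkSize = 4096) {
    std::vector<std::vector<char>> blocks;
    FILE* f = fopen(path, "rb");
    if (nullptr == f) {
        return {};
    }
    size_t total = 0;
    while (true) {
        // Read one fixed-size block, as FileLoader::read does conceptually.
        std::vector<char> block(chunkSize);
        size_t n = fread(block.data(), 1, chunkSize, f);
        if (n == 0) {
            break;
        }
        block.resize(n);
        total += n;
        blocks.emplace_back(std::move(block));
    }
    fclose(f);
    // Merge every block into a single buffer, as FileLoader::merge does.
    std::vector<char> merged;
    merged.reserve(total);
    for (auto& b : blocks) {
        merged.insert(merged.end(), b.begin(), b.end());
    }
    return merged;
}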
Interpreter::createFromBufferInternal
A flatbuffers Verifier checks the integrity of the buffer.
if (nullptr == net) {
MNN_PRINT("Buffer is null for create interpreter\n");
return nullptr;
}
flatbuffers::Verifier verify((const uint8_t*)(net->buffer.get()), net->buffer.size());
if (false == VerifyNetBuffer(verify)) {
MNN_PRINT("Invalidate buffer to create interpreter\n");
delete net;
return nullptr;
}
Fetch the Net defined in the buffer and validate its op list.
net->net = GetNet(net->buffer.get());
if (nullptr == net->net->oplists()) {
MNN_ERROR("Model has no oplist\n");
delete net;
return nullptr;
}
int opSize = net->net->oplists()->size();
for (int i=0; i<opSize; ++i) {
auto op = net->net->oplists()->GetAs<Op>(i);
if (nullptr == op || nullptr == op->outputIndexes()) {
MNN_ERROR("Invalid Model, the %d op is empty\n", i);
delete net;
return nullptr;
}
}
Finally an Interpreter object is created. The advantage of such a factory-style create function is that errors can be checked and reported as early as possible, although it cannot return an error code to the caller.
return new Interpreter(net);
Interpreter::createSession
Create a session from the given ScheduleConfig; this simply delegates to createMultiPathSession.
return createMultiPathSession({config});
Interpreter::createMultiPathSession
First check whether the model buffer has already been released.
if (nullptr == mNet->buffer.get()) {
MNN_ERROR("The model buffer has been released. Can't create session\n");
return nullptr;
}
Schedule::schedule builds a Schedule::ScheduleInfo struct from the network and the configs, and a Session is created from that result.
Session::resize then allocates and plans memory for the session, and the new Session is appended to the list held by the Interpreter.
auto info = Schedule::schedule(mNet->net, configs);
auto newSession = std::unique_ptr<Session>(new Session(info));
if (!newSession->valid()) {
MNN_PRINT("Invalide Session!!\n");
return nullptr;
}
auto result = newSession.get();
if (info.validForResize) {
result->resize();
}
mNet->sessions.emplace_back(std::move(newSession));
return result;
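A sketch of how a caller might schedule two paths on different backends through this API, continuing from the earlier overview sketch (net is the Interpreter); the tensor names and the OpenCL backend are illustrative assumptions and must match the actual model and build.
// Two ScheduleConfigs produce two Pipelines inside one Session.
MNN::ScheduleConfig cpuConfig;
cpuConfig.type         = MNN_FORWARD_CPU;
cpuConfig.numThread    = 4;
cpuConfig.path.outputs = {"branch1/output"};  // hypothetical name

MNN::ScheduleConfig gpuConfig;
gpuConfig.type         = MNN_FORWARD_OPENCL;
gpuConfig.path.outputs = {"branch2/output"};  // hypothetical name

auto session = net->createMultiPathSession({cpuConfig, gpuConfig});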
Interpreter::getSessionInput
Session::getInput looks up the tensor by name in the input dictionary.
The tensor-to-Session association is then recorded in tensorMap.
MNN_ASSERT(nullptr != session);
if (session == nullptr) {
return nullptr;
}
auto tensor = session->getInput(name);
mNet->tensorMap.insert(std::make_pair(tensor, session));
return tensor;
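A usage sketch for feeding data: look up the input, stage host data in a temporary host tensor, and copy it in. It continues from the earlier overview sketch; the input name "data" and the zero-filled values are illustrative assumptions.
auto input = net->getSessionInput(session, "data");  // "data" is hypothetical
std::vector<float> values(input->elementSize(), 0.0f);
// Stage the data in a host tensor of the same shape, then copy it in.
auto hostTensor = MNN::Tensor::create<float>(input->shape(), values.data(),
                                             MNN::Tensor::CAFFE);
input->copyFromHostTensor(hostTensor);
delete hostTensor;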
Interpreter::resizeTensor
Dimension check: the tensor is only marked dirty if the requested dimensions differ from the current ones; otherwise return early.
MNN_ASSERT(nullptr != tensor);
bool dirty = false;
if (tensor->buffer().dimensions != dims.size()) {
dirty = true;
} else {
for (int i = 0; i < dims.size(); ++i) {
if (tensor->buffer().dim[i].extent != dims[i]) {
dirty = true;
break;
}
}
}
if (!dirty) {
return;
}
Set the tensor's dimensions.
tensor->buffer().dimensions = (int)dims.size();
for (int i = 0; i < dims.size(); ++i) {
tensor->buffer().dim[i].extent = dims[i];
}
Mark the Session associated with this tensor as needing a resize, via Session::setNeedResize.
auto relatedSessionIter = mNet->tensorMap.find(tensor);
MNN_ASSERT(relatedSessionIter != mNet->tensorMap.end());
((MNN::Session*)relatedSessionIter->second)->setNeedResize();
Interpreter::resizeSession
if (mNet->buffer.get() == nullptr) {
MNN_ERROR("The model buffer has been released. Can't resize session\n");
return;
}
if (session->getNeedResize()) {
session->resize();
}
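Put together, resizing an input and re-planning the session looks like this in client code, continuing from the earlier overview sketch (the dimensions are illustrative):
auto input = net->getSessionInput(session, nullptr);
// Marks the tensor (and its session) dirty only if the dims really change.
net->resizeTensor(input, {8, 3, 224, 224});
// Re-plans memory; a no-op when the session was not marked dirty.
net->resizeSession(session);
net->runSession(session);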
Interpreter::runSession
return session->run();
Interpreter::runSessionWithCallBack
Wrap the callbacks: the name-based before/after callbacks are adapted to the OperatorInfo-based interface (a usage sketch follows the code).
auto beforeWrap = [&before](const std::vector<Tensor*>& tensors, const OperatorInfo* info) {
return before(tensors, info->name());
};
auto afterWrap = [&after](const std::vector<Tensor*>& tensors, const OperatorInfo* info) {
return after(tensors, info->name());
};
return runSessionWithCallBackInfo(session, beforeWrap, afterWrap, sync);
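For example, the name-based callbacks can be used for per-op tracing; TensorCallBack receives the op's tensors and its name. A sketch, continuing from the earlier overview sketch:
MNN::TensorCallBack before = [](const std::vector<MNN::Tensor*>& tensors,
                                const std::string& opName) {
    MNN_PRINT("enter op %s\n", opName.c_str());
    return true;  // returning false skips running this op
};
MNN::TensorCallBack after = [](const std::vector<MNN::Tensor*>& tensors,
                               const std::string& opName) {
    MNN_PRINT("leave op %s\n", opName.c_str());
    return true;  // returning false stops the session early
};
net->runSessionWithCallBack(session, before, after);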
Interpreter::runSessionWithCallBackInfo
The parameter name callBack is a bit inconsistent: the pre-execution callback is named before while the post-execution one is named callBack. The call forwards directly to Session::runWithCallBack.
return session->runWithCallBack(before, callBack, sync);
Interpreter::releaseModel
mNet->buffer.release();
for (auto& iter : mNet->sessions) {
iter->releaseCache();
}
Interpreter::updateSessionToModel
if (mNet->buffer.get() == nullptr) {
MNN_ERROR("Can't updateSessionToModel because you called releaseModel before\n");
return INPUT_DATA_ERROR;
}
return session->updateToModel((Net*)mNet->net);
Session
A Session owns a number of Pipelines and records its input and output tensors.
Session::Session
if (info.pipelineInfo.empty()) {
mValid = false;
return;
}
A Pipeline is created for each Schedule::PipelineInfo.
BackendFactory::create creates a backend from a Backend::Info; the mBackends map caches the created backends keyed by forward type.
Session::_getDefaultBackend returns the CPUBackend.
Since the ARMv8.2-A extension adds fp16 arithmetic, on supporting CPUs the default CPU backend can be swapped for an Arm82Backend, as the #ifdef block below shows (a configuration sketch follows the code).
mTensors = info.allTensors;
for (auto& iter : info.pipelineInfo) {
if (mBackends.find(iter.first.type) == mBackends.end()) {
auto newBn = BackendFactory::create(iter.first);
if (nullptr == newBn) {
mValid = false;
return;
}
mBackends[iter.first.type].reset(newBn);
}
auto backend = mBackends.find(iter.first.type)->second.get();
auto cpuBackend = _getDefaultBackend();
#if defined(__aarch64__) && defined(ENABLE_ARMV82)
// choose Arm82Backend only when setting BackendConfig PrecisionMode
// to be Precision_Normal|Precision_Low
auto precisionModeSatisfy = false;
if(iter.first.user){
auto precisionMode = iter.first.user->precision;
if(precisionMode == BackendConfig::Precision_Low){
precisionModeSatisfy = true;
}
}
if (iter.first.type == MNN_FORWARD_CPU && precisionModeSatisfy && cpuBackend->mIsSupportFp16arith) {
// if (iter.first.type == MNN_FORWARD_CPU) { // debug on Mac
// when enable armv82 extension instruction set and forward type is cpu and cpu isa support fp16arith
// activate armv82 backend
// check backend is equal to be cpuBackend
MNN_ASSERT(backend == cpuBackend);
if (mBackends.find(MNN_FORWARD_CPU_EXTENSION) == mBackends.end()) {
Backend::Info bnInfo;
bnInfo.type = MNN_FORWARD_CPU_EXTENSION;
BackendConfig config;
config.sharedContext = static_cast<void*>(cpuBackend);
bnInfo.user = &config;
mBackends[bnInfo.type].reset(BackendFactory::create(bnInfo));
}
backend = mBackends.find(MNN_FORWARD_CPU_EXTENSION)->second.get();
if (backend == nullptr) {
MNN_PRINT("[MNNWarning]: armv82 backend is null\n");
backend = cpuBackend;
}
MNN_PRINT("\n[MNNInfo]:*************set armv82 backend*************\n");
}
#endif
std::shared_ptr<Pipeline> newPipeline(new Pipeline(iter.second, backend, cpuBackend));
mPipelines.emplace_back(std::move(newPipeline));
}
mInputs = info.inputTensors;
mOutputs = info.outputTensor;
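Accordingly, a caller opts into the Arm82 (fp16) backend by requesting low precision in the BackendConfig. A sketch, continuing from the earlier overview sketch:
MNN::BackendConfig backendConfig;
backendConfig.precision = MNN::BackendConfig::Precision_Low;

MNN::ScheduleConfig config;
config.type          = MNN_FORWARD_CPU;
config.numThread     = 4;
config.backendConfig = &backendConfig;  // must outlive createSession

// On an ARMv8.2-A CPU in a build with ENABLE_ARMV82, the Session then
// substitutes the MNN_FORWARD_CPU_EXTENSION (Arm82) backend.
auto session = net->createSession(config);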
Session::_getDefaultBackend
If no MNN_FORWARD_CPU backend exists yet, one is created and added to the mBackends map.
auto defaultType = MNN_FORWARD_CPU;
if (mBackends.find(defaultType) == mBackends.end()) {
Backend::Info info;
info.type = defaultType;
info.numThread = 1;
mBackends[info.type].reset(BackendFactory::create(info));
}
auto cpuBackend = mBackends.find(defaultType)->second.get();
return cpuBackend;
Session::~Session
TensorUtils::clearHandleData clears the tensor's host-side handle data.
for (auto& t : mTensors) {
TensorUtils::clearHandleData(t.second.get());
}
mPipelines.clear();
mBackends.clear();
mTensors.clear();
Session::resize
Session::_clearCache clears the host-side cache; each pipeline is then prepared via Pipeline::Unit::prepare.
_clearCache();
for (auto& b : mBackends) {
// avoid library not loaded
if(b.second){
b.second->onClearBuffer();
}
}
for (auto& iter : mPipelines) {
auto error = iter->prepare();
if (NO_ERROR != error) {
return error;
}
}
mNeedResize = false;
for (auto& b : mBackends) {
if(b.second){
b.second->onAllocateBuffer();
}
}
return NO_ERROR;
Session::_clearCache
TensorUtils::getDescribe returns the Tensor::InsideDescribe struct, which records the tensor's layout and the data type of its handle.
TensorUtils::clearHandleData releases the data held by the tensor. Note that mTensors stores (useCount, tensor) pairs, so the loop below restores the reference count and detaches the backend.
for (auto& t : mTensors) {
auto describe = TensorUtils::getDescribe(t.second.get());
TensorUtils::clearHandleData(t.second.get());
describe->useCount = t.first;
describe->backend = nullptr;
}
Session::run
if (mNeedResize) {
MNN_ERROR("Can't run session because not resized\n");
return COMPUTE_SIZE_ERROR;
}
for (auto& iter : mPipelines) {
auto error = iter->execute();
if (NO_ERROR != error) {
return error;
}
}
return NO_ERROR;
Session::runWithCallBack
if (mNeedResize) {
MNN_ERROR("Can't run session because not resized\n");
return COMPUTE_SIZE_ERROR;
}
for (auto& iter : mPipelines) {
auto error = iter->executeCallBack(before, end);
if (NO_ERROR != error) {
return error;
}
}
if (sync) {
for (auto& bn : mBackends) {
if(bn.second){
bn.second->onWaitFinish();
}
}
}
return NO_ERROR;
Session::releaseCache
for (auto& p : mPipelines) {
auto code = p->releaseCache();
if (NO_ERROR != code) {
return code;
}
}
return NO_ERROR;
Session::updateToModel
Iterate over the network's op list:
- when the network is used for inference, skip any op that is not a constant;
- when it is used for training, skip any op that is not a trainable parameter;
- skip ops whose output count is not exactly one.
int opSize = net->oplists()->size();
for (int i = 0; i < opSize; ++i) {
auto op = net->oplists()->GetAs<Op>(i);
if (net->usage() == Usage_INFERENCE && op->type() != OpType_Const) {
continue;
}
if (net->usage() == Usage_TRAIN && op->type() != OpType_TrainableParam) {
continue;
}
if (!op->outputIndexes() || op->outputIndexes()->size() != 1) {
continue;
}
The op's parameter is a Blob; ops whose data type is not float are skipped.
auto index = op->outputIndexes()->data()[0];
auto blob = op->main_as_Blob();
if (blob->dataType() != DataType_DT_FLOAT) {
continue;
}
Tensor::createHostTensorFromDevice creates a host-side copy of a device tensor.
The tensor's contents are then copied back into the Blob, i.e. into the model (note the memcpy direction: the Blob is the destination).
std::shared_ptr<Tensor> tensor = mTensors[index].second;
if (tensor->host<void>() == nullptr && tensor->deviceId() != 0) {
tensor.reset(Tensor::createHostTensorFromDevice(tensor.get(), true));
if (tensor.get() == nullptr) {
MNN_ERROR("failed to copy trained param from device to host\n");
return INVALID_VALUE;
}
}
::memcpy((void*)blob->float32s()->data(), tensor->host<float>(), tensor->size());
}
return NO_ERROR;
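A sketch of the typical write-back flow after training, continuing from the earlier overview sketch; I assume getModelBuffer from the same header exposes the (still unreleased) model buffer, and "trained.mnn" is an arbitrary output path.
#include <cstdio>

// Copy updated weights from the session back into the model buffer,
// then persist the buffer (assumes releaseModel was never called).
net->updateSessionToModel(session);
auto buffer = net->getModelBuffer();
FILE* f = fopen("trained.mnn", "wb");
if (nullptr != f) {
    fwrite(buffer.first, 1, buffer.second, f);
    fclose(f);
}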