
MNN Interpreter and Session


The relationship between Interpreter and Session in MNN is shown in the figure below:

[Figure: an Interpreter owns one Net and multiple Sessions (Session0, Session1); each Session owns several Pipelines (Pipeline0 through Pipeline4).]

Interpreter

An Interpreter is constructed from a model file; copy construction and copy assignment are disallowed.
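A minimal end-to-end usage sketch (the model file name is illustrative, and error handling is omitted):

    #include <MNN/Interpreter.hpp>
    #include <memory>

    auto interpreter = std::shared_ptr<MNN::Interpreter>(
        MNN::Interpreter::createFromFile("model.mnn"));              // hypothetical path
    MNN::ScheduleConfig config;
    auto session = interpreter->createSession(config);
    auto input   = interpreter->getSessionInput(session, nullptr);   // default input
    // ... fill input ...
    interpreter->runSession(session);
    auto output  = interpreter->getSessionOutput(session, nullptr);  // default output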

Interpreter::createFromFile

[Flowchart: Interpreter::createFromFile → FileLoader → FileLoader::valid? (no → end) → FileLoader::read → Content → FileLoader::merge → createFromBufferInternal → end]

Interpreter::createFromFile first reads the file contents in blocks with the help of FileLoader; after merging them, Interpreter::createFromBufferInternal creates the Interpreter object.
FileLoader::valid checks whether the file handle is null.
FileLoader::read reads the file in blocks.
Content holds the read buffer as well as the net, session, and tensor information built from it.
FileLoader::merge merges the blocks that were read into the given buffer.

    if (nullptr == file) {
        MNN_PRINT("NULL file for create interpreter\n");
        return nullptr;
    }
    std::unique_ptr<FileLoader> loader(new FileLoader(file));
    if (!loader->valid()) {
        MNN_PRINT("Create interpreter failed, open %s error\n", file);
        return nullptr;
    }
    bool result = loader->read();
    if (!result) {
        MNN_PRINT("Read file error\n");
        return nullptr;
    }
    if (loader->size() == 0) {
        MNN_PRINT("Create interpreter failed, %s is empty\n", file);
        return nullptr;
    }
    auto net     = new Content;
    bool success = loader->merge(net->buffer);
    if (!success) {
        return nullptr;
    }
    loader.reset();
    return createFromBufferInternal(net);
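
The chunked read-and-merge idea behind FileLoader can be sketched as follows (a simplified illustration, not MNN's actual implementation; the block size and names are made up):

    #include <cstring>
    #include <fstream>
    #include <memory>
    #include <vector>

    std::unique_ptr<char[]> readAndMerge(const char* path, size_t& totalSize) {
        std::ifstream in(path, std::ios::binary);
        if (!in.good()) {                                     // cf. FileLoader::valid
            return nullptr;
        }
        const size_t kBlock = 4096;                           // illustrative block size
        std::vector<std::vector<char>> blocks;                // cf. FileLoader::read
        totalSize = 0;
        for (;;) {
            std::vector<char> block(kBlock);
            in.read(block.data(), kBlock);
            auto got = static_cast<size_t>(in.gcount());
            if (got == 0) {
                break;
            }
            block.resize(got);
            totalSize += got;
            blocks.emplace_back(std::move(block));
        }
        std::unique_ptr<char[]> merged(new char[totalSize]);  // cf. FileLoader::merge
        size_t offset = 0;
        for (auto& b : blocks) {
            ::memcpy(merged.get() + offset, b.data(), b.size());
            offset += b.size();
        }
        return merged;
    }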

Interpreter::createFromBufferInternal

Interpreter::createFromBufferInternal
Interpreter::Interpreter

Verifier checks the integrity of the buffer.

    if (nullptr == net) {
        MNN_PRINT("Buffer is null for create interpreter\n");
        return nullptr;
    }
    flatbuffers::Verifier verify((const uint8_t*)(net->buffer.get()), net->buffer.size());
    if (false == VerifyNetBuffer(verify)) {
        MNN_PRINT("Invalidate buffer to create interpreter\n");
        delete net;
        return nullptr;
    }

Fetch the Net defined in the buffer and check its op list.

    net->net = GetNet(net->buffer.get());
    if (nullptr == net->net->oplists()) {
        MNN_ERROR("Model has no oplist\n");
        delete net;
        return nullptr;
    }
    int opSize = net->net->oplists()->size();
    for (int i=0; i<opSize; ++i) {
        auto op = net->net->oplists()->GetAs<Op>(i);
        if (nullptr == op || nullptr == op->outputIndexes()) {
            MNN_ERROR("Invalid Model, the %d op is empty\n", i);
            delete net;
            return nullptr;
        }
    }

Finally, an Interpreter object is created. The advantage of a create factory function is that errors are checked promptly and reported by returning nullptr early; the drawback is that no error code can be passed back. A sketch of this idiom follows the code.

    return new Interpreter(net);
}
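
The idiom can be sketched like this (illustrative, not the actual MNN header): a private constructor plus deleted copy operations force all construction through the checked factory.

    class Interpreter {
    public:
        // Returns nullptr on any error; no error code is available to the caller.
        static Interpreter* createFromBufferInternal(Content* net);
        Interpreter(const Interpreter&) = delete;             // no copy construction
        Interpreter& operator=(const Interpreter&) = delete;  // no copy assignment
    private:
        explicit Interpreter(Content* net);                   // reachable only via create*
    };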

Interpreter::createSession

Interpreter::createSession
Interpreter::createMultiPathSession

Create a session according to the given ScheduleConfig.

    return createMultiPathSession({config});

Interpreter::createMultiPathSession

[Flowchart: Interpreter::createMultiPathSession → Schedule::schedule → new Session → Session::valid? (no → end) → validForResize? (yes → Session::resize) → return the new Session]

Check whether the network buffer has already been released.

    if (nullptr == mNet->buffer.get()) {
        MNN_ERROR("The model buffer has been released. Can't create session\n");
        return nullptr;
    }

The Schedule::schedule function builds a Schedule::ScheduleInfo from the net and the configs, and the Session is created from that result.
If the info is valid for resizing, Session::resize is called; the new Session is then appended to the list held by the Interpreter.

    auto info       = Schedule::schedule(mNet->net, configs);
    auto newSession = std::unique_ptr<Session>(new Session(info));
    if (!newSession->valid()) {
        MNN_PRINT("Invalide Session!!\n");
        return nullptr;
    }
    auto result = newSession.get();
    if (info.validForResize) {
        result->resize();
    }
    mNet->sessions.emplace_back(std::move(newSession));
    return result;
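
A typical call with an explicit ScheduleConfig might look like this (a sketch; the field values are illustrative):

    MNN::ScheduleConfig config;
    config.type      = MNN_FORWARD_CPU;  // forward type selects the backend
    config.numThread = 4;
    MNN::BackendConfig backendConfig;
    backendConfig.precision = MNN::BackendConfig::Precision_Low;  // may enable the fp16 path in Session::Session
    config.backendConfig = &backendConfig;
    auto session = interpreter->createSession(config);  // delegates to createMultiPathSession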

Interpreter::getSessionInput

Interpreter::getSessionInput
Session::getInput

Session::getInput finds the tensor by name in the session's input map.
The association between the tensor and its Session is recorded in tensorMap.

    MNN_ASSERT(nullptr != session);
    if (session == nullptr) {
        return nullptr;
    }
    auto tensor = session->getInput(name);
    mNet->tensorMap.insert(std::make_pair(tensor, session));
    return tensor;
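
Feeding the returned input tensor typically goes through a host-side copy (a sketch; the input name "input" and the zero-filled data are illustrative):

    auto input = interpreter->getSessionInput(session, "input");  // hypothetical name
    std::vector<float> data(input->elementSize(), 0.0f);
    MNN::Tensor hostTensor(input, input->getDimensionType());     // host tensor, same shape/layout
    ::memcpy(hostTensor.host<float>(), data.data(), data.size() * sizeof(float));
    input->copyFromHostTensor(&hostTensor);                       // upload to the session's backend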

Interpreter::resizeTensor

Dimension check: the tensor is marked dirty only if its rank or any extent differs from the requested dims.

    MNN_ASSERT(nullptr != tensor);
    bool dirty = false;
    if (tensor->buffer().dimensions != dims.size()) {
        dirty = true;
    } else {
        for (int i = 0; i < dims.size(); ++i) {
            if (tensor->buffer().dim[i].extent != dims[i]) {
                dirty = true;
                break;
            }
        }
    }

    if (!dirty) {
        return;
    }

Set the tensor's dimensions.

    tensor->buffer().dimensions = (int)dims.size();
    for (int i = 0; i < dims.size(); ++i) {
        tensor->buffer().dim[i].extent = dims[i];
    }

Mark the Session associated with this tensor as needing a resize.
Session::setNeedResize

    auto relatedSessionIter = mNet->tensorMap.find(tensor);
    MNN_ASSERT(relatedSessionIter != mNet->tensorMap.end());
    ((MNN::Session*)relatedSessionIter->second)->setNeedResize();
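
The usual flow pairs resizeTensor with resizeSession (a sketch; the shape is illustrative):

    auto input = interpreter->getSessionInput(session, nullptr);
    interpreter->resizeTensor(input, {1, 3, 224, 224});  // marks the owning Session as needing resize
    interpreter->resizeSession(session);                 // reallocates only if setNeedResize was called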

Interpreter::resizeSession

Interpreter::resizeSession
Session::resize


    if (mNet->buffer.get() == nullptr) {
        MNN_ERROR("The model buffer has been released. Can't resize session\n");
        return;
    }
    if (session->getNeedResize()) {
        session->resize();
    }

Interpreter::runSession

Interpreter::runSession
Session::run


    return session->run();

Interpreter::runSessionWithCallBack

Interpreter::runSessionWithCallBack
Interpreter::runSessionWithCallBackInfo

Wrap the callback functions: the name-based callbacks are adapted to the OperatorInfo-based interface.

    auto beforeWrap = [&before](const std::vector<Tensor*>& tensors, const OperatorInfo* info) {
        return before(tensors, info->name());
    };
    auto afterWrap = [&after](const std::vector<Tensor*>& tensors, const OperatorInfo* info) {
        return after(tensors, info->name());
    };
    return runSessionWithCallBackInfo(session, beforeWrap, afterWrap, sync);
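
For example, the name-based callbacks can be used for per-op tracing (a sketch; in MNN, returning false from the before callback skips the op, and returning false from the after callback stops the run):

    MNN::TensorCallBack before = [](const std::vector<MNN::Tensor*>& tensors,
                                    const std::string& opName) {
        MNN_PRINT("enter op %s\n", opName.c_str());
        return true;  // false would skip executing this op
    };
    MNN::TensorCallBack after = [](const std::vector<MNN::Tensor*>& tensors,
                                   const std::string& opName) {
        return true;  // false would stop the session run early
    };
    interpreter->runSessionWithCallBack(session, before, after);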

Interpreter::runSessionWithCallBackInfo

Interpreter::runSessionWithCallBackInfo
Session::runWithCallBack

The parameter name callBack is inconsistent here (its counterpart is named before).
Session::runWithCallBack

    return session->runWithCallBack(before, callBack, sync);

Interpreter::releaseModel

Session::releaseCache

    mNet->buffer.release();
    for (auto& iter : mNet->sessions) {
        iter->releaseCache();
    }

Interpreter::updateSessionToModel

Interpreter::updateSessionToModel
Session::updateToModel


    if (mNet->buffer.get() == nullptr) {
        MNN_ERROR("Can't updateSessionToModel because you called releaseModel before\n");
        return INPUT_DATA_ERROR;
    }
    return session->updateToModel((Net*)mNet->net);

Session

A Session owns multiple Pipelines and records the input and output tensors.
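
A simplified view of the members referenced throughout this section (a sketch inferred from the code below, not the full class):

    class Session {
        std::map<MNNForwardType, std::unique_ptr<Backend>> mBackends;   // backends cached by type
        std::vector<std::shared_ptr<Pipeline>> mPipelines;              // one per Schedule::PipelineInfo
        std::vector<std::pair<int, std::shared_ptr<Tensor>>> mTensors;  // (useCount, tensor) pairs
        std::map<std::string, Tensor*> mInputs;                         // input tensors by name
        std::map<std::string, Tensor*> mOutputs;                        // output tensors by name
        bool mValid      = true;
        bool mNeedResize = false;
    };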

Session::Session

Session::Session
BackendFactory::create
Session::_getDefaultBackend
Pipeline::Pipeline

    if (info.pipelineInfo.empty()) {
        mValid = false;
        return;
    }

A Pipeline is created for each Schedule::PipelineInfo.
BackendFactory::create creates a backend from a Backend::Info; the mBackends map stores created backends by type.
Session::_getDefaultBackend returns the CPUBackend.
ARMv8-A supports fp16 arithmetic, which the conditional branch below takes advantage of.

    mTensors = info.allTensors;
    for (auto& iter : info.pipelineInfo) {
        if (mBackends.find(iter.first.type) == mBackends.end()) {
            auto newBn = BackendFactory::create(iter.first);
            if (nullptr == newBn) {
                mValid = false;
                return;
            }
            mBackends[iter.first.type].reset(newBn);
        }
        auto backend    = mBackends.find(iter.first.type)->second.get();
        auto cpuBackend = _getDefaultBackend();

#if defined(__aarch64__) && defined(ENABLE_ARMV82)
        // choose Arm82Backend only when setting BackendConfig PrecisionMode
        // to be Precision_Normal|Precision_Low
        auto precisionModeSatisfy = false;
        if(iter.first.user){
            auto precisionMode = iter.first.user->precision;
            if(precisionMode == BackendConfig::Precision_Low){
                precisionModeSatisfy = true;
            }
        }
        if (iter.first.type == MNN_FORWARD_CPU && precisionModeSatisfy && cpuBackend->mIsSupportFp16arith) {
        // if (iter.first.type == MNN_FORWARD_CPU) { // debug on Mac
            // when enable armv82 extension instruction set and forward type is cpu and cpu isa support fp16arith
            // activate armv82 backend
            // check backend is equal to be cpuBackend
            MNN_ASSERT(backend == cpuBackend);
            if (mBackends.find(MNN_FORWARD_CPU_EXTENSION) == mBackends.end()) {
                Backend::Info bnInfo;
                bnInfo.type = MNN_FORWARD_CPU_EXTENSION;
                BackendConfig config;
                config.sharedContext = static_cast<void*>(cpuBackend);
                bnInfo.user          = &config;
                mBackends[bnInfo.type].reset(BackendFactory::create(bnInfo));
            }
            backend = mBackends.find(MNN_FORWARD_CPU_EXTENSION)->second.get();
            if (backend == nullptr) {
                MNN_PRINT("[MNNWarning]: armv82 backend is null\n");
                backend = cpuBackend;
            }
            MNN_PRINT("\n[MNNInfo]:*************set armv82 backend*************\n");
        }
#endif
        std::shared_ptr<Pipeline> newPipeline(new Pipeline(iter.second, backend, cpuBackend));
        mPipelines.emplace_back(std::move(newPipeline));
    }
    mInputs  = info.inputTensors;
    mOutputs = info.outputTensor;

Session::_getDefaultBackend

If no MNN_FORWARD_CPU backend exists yet, create one and add it to the mBackends map.

    auto defaultType = MNN_FORWARD_CPU;
    if (mBackends.find(defaultType) == mBackends.end()) {
        Backend::Info info;
        info.type      = defaultType;
        info.numThread = 1;
        mBackends[info.type].reset(BackendFactory::create(info));
    }
    auto cpuBackend = mBackends.find(defaultType)->second.get();
    return cpuBackend;

Session::~Session

TensorUtils::clearHandleData clears each Tensor's host-side handle data.

    for (auto& t : mTensors) {
        TensorUtils::clearHandleData(t.second.get());
    }
    mPipelines.clear();
    mBackends.clear();
    mTensors.clear();

Session::resize

Session::_clearCache clears the host-side cache.
Pipeline::Unit::prepare

    _clearCache();
    for (auto& b : mBackends) {
        // avoid library not loaded
        if(b.second){
            b.second->onClearBuffer();
        }
    }

    for (auto& iter : mPipelines) {
        auto error = iter->prepare();
        if (NO_ERROR != error) {
            return error;
        }
    }
    mNeedResize = false;
    for (auto& b : mBackends) {
        if(b.second){
            b.second->onAllocateBuffer();
        }
    }

    return NO_ERROR;

Session::_clearCache

TensorUtils::getDescribe returns the Tensor::InsideDescribe struct, which records the tensor's layout and the data type of its handle.
TensorUtils::clearHandleData frees the data held in the tensor.

    for (auto& t : mTensors) {
        auto describe = TensorUtils::getDescribe(t.second.get());
        TensorUtils::clearHandleData(t.second.get());
        describe->useCount = t.first;
        describe->backend  = nullptr;
    }

Session::run

Session::run
Pipeline::execute


    if (mNeedResize) {
        MNN_ERROR("Can't run session because not resized\n");
        return COMPUTE_SIZE_ERROR;
    }
    for (auto& iter : mPipelines) {
        auto error = iter->execute();
        if (NO_ERROR != error) {
            return error;
        }
    }
    return NO_ERROR;

Session::runWithCallBack

Session::runWithCallBack
Pipeline::executeCallBack
Backend::onWaitFinish


    if (mNeedResize) {
        MNN_ERROR("Can't run session because not resized\n");
        return COMPUTE_SIZE_ERROR;
    }
    for (auto& iter : mPipelines) {
        auto error = iter->executeCallBack(before, end);
        if (NO_ERROR != error) {
            return error;
        }
    }
    if (sync) {
        for (auto& bn : mBackends) {
            if(bn.second){
                bn.second->onWaitFinish();
            }
        }
    }
    return NO_ERROR;

Session::releaseCache

Pipeline::releaseCache

    for (auto& p : mPipelines) {
        auto code = p->releaseCache();
        if (NO_ERROR != code) {
            return code;
        }
    }
    return NO_ERROR;

Session::updateToModel

Iterate over the network's op list:

  • when the net is used for inference, skip ops that are not constants;
  • when it is used for training, skip ops that are not trainable parameters;
  • skip ops whose output count is not exactly one.

    int opSize = net->oplists()->size();
    for (int i = 0; i < opSize; ++i) {
        auto op = net->oplists()->GetAs<Op>(i);
        if (net->usage() == Usage_INFERENCE && op->type() != OpType_Const) {
            continue;
        }
        if (net->usage() == Usage_TRAIN && op->type() != OpType_TrainableParam) {
            continue;
        }
        if (!op->outputIndexes() || op->outputIndexes()->size() != 1) {
            continue;
        }

The op's parameter is a Blob; skip ops whose data type is not float.

        auto index = op->outputIndexes()->data()[0];
        auto blob  = op->main_as_Blob();
        if (blob->dataType() != DataType_DT_FLOAT) {
            continue;
        }

Tensor::createHostTensorFromDevice creates a host-side tensor from the device tensor.
The Tensor contents are then copied into the Blob (note the memcpy direction: the Blob is the destination).

        std::shared_ptr<Tensor> tensor = mTensors[index].second;
        if (tensor->host<void>() == nullptr && tensor->deviceId() != 0) {
            tensor.reset(Tensor::createHostTensorFromDevice(tensor.get(), true));
            if (tensor.get() == nullptr) {
                MNN_ERROR("failed to copy trained param from device to host\n");
                return INVALID_VALUE;
            }
        }
        ::memcpy((void*)blob->float32s()->data(), tensor->host<float>(), tensor->size());
    }

    return NO_ERROR;
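
After the update, the modified model can be written back to disk; a sketch using Interpreter::getModelBuffer, which returns the raw buffer pointer and its size (the output file name is illustrative):

    #include <fstream>

    interpreter->updateSessionToModel(session);
    auto buffer = interpreter->getModelBuffer();               // pair of (const void*, size_t)
    std::ofstream out("model_updated.mnn", std::ios::binary);  // hypothetical file name
    out.write(static_cast<const char*>(buffer.first), buffer.second);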
