欢迎您访问程序员文章站！本站旨在为大家提供、分享程序员计算机编程知识！
您现在的位置是: 首页  >  IT编程

腾讯开源框架TNN调用过程

程序员文章站 2022-04-12 08:48:28
第一步：模型转换，按照GitHub上的说明一步一步操作即可，此处无坑。第二步：用CMake生成VS工程，需要在CMakeLists.txt中开启要使用的加速后端（accelerator），否则GetDevice会返回NULL。第三步：调用（完整示例代码见下文）。

第一步：模型转换。按照GitHub上的文档一步一步操作即可，此处无坑。

第二步：用CMake生成VS工程。需要在CMakeLists.txt中开启要使用的加速后端（accelerator）对应的编译选项，否则GetDevice会返回NULL。

第三步：调用（示例代码如下）。

#include <cstdio>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

#include "tnn/utils/dims_vector_utils.h"
#include "tnn/utils/blob_transfer_utils.h"
#include "tnn/core/tnn.h"
// Transposes a dense float tensor from NHWC layout (src) into NCHW layout (dst).
// Both buffers must hold batch*channel*height*width floats and must not
// overlap; every element of dst is written exactly once.
static void ModifynhwcTonchw(float* dst, float* src,
	int batch, int channel,
	int height, int width)
{
	const int plane = height * width;   // elements in one NCHW channel plane
	const int image = plane * channel;  // elements per batch item (same in both layouts)

	for (int n = 0; n < batch; ++n)
	{
		const float* src_image = src + n * image;
		float* dst_image = dst + n * image;
		for (int c = 0; c < channel; ++c)
		{
			// The NCHW plane for channel c is contiguous, so write it sequentially.
			float* out = dst_image + c * plane;
			for (int h = 0; h < height; ++h)
			{
				// In NHWC a row holds width pixels, each with channel interleaved values.
				const float* in_row = src_image + h * width * channel;
				for (int w = 0; w < width; ++w)
				{
					*out++ = in_row[w * channel + c];
				}
			}
		}
	}
}

void CopyDataToDeviceFromFile(TNN_NS::BlobMap blob_map,std::string input_file,void* command_queue)
{
	//get input_blob info
	std::string input_name = (blob_map.begin())->first;
	TNN_NS::Blob* device_blob = (blob_map.begin())->second;
	TNN_NS::BlobConverter blob_converter(device_blob);
	TNN_NS::BlobDesc blob_desc = device_blob->GetBlobDesc();
	
	//get input data
	TNN_NS::BlobHandle data_handle;
	int data_count = TNN_NS::DimsVectorUtils::Count(blob_desc.dims);
	float* input_data = (float*)malloc(data_count * sizeof(float));
	FILE* fp = fopen(input_file.data(), "rb");
	if (fp == NULL)
	{
		printf("CopyDataToDeviceFromFile Err,read input file failed: %s\n",input_file.data());
	}
	fread(input_data, data_count, sizeof(float), fp);
	fclose(fp);

	//if necessary
	if (1)
	{
		float* trans_data = (float*)malloc(data_count * sizeof(float));
		ModifynhwcTonchw(trans_data, input_data,
			blob_desc.dims[0], blob_desc.dims[1],
			blob_desc.dims[2], blob_desc.dims[3]);
		free(input_data);
		input_data = trans_data;
	}

	data_handle.base = input_data;
	data_handle.bytes_offset = 0;

	//convert
	TNN_NS::Blob data_blob(blob_desc,data_handle);
	TNN_NS::CopyToDevice(device_blob, &data_blob, command_queue);

	free(input_data);
}

// Downloads the first blob of `blob_map` (first in std::map key order) from the
// device and writes it to `out_file` as text, one "%f"-formatted value per line.
//
// blob_map      - output blobs, e.g. filled by Instance::GetAllOutputBlobs().
// out_file      - path of the text file to create or overwrite.
// command_queue - passed through unchanged to TNN_NS::CopyFromDevice().
// On any failure an error is printed and the function returns without writing.
void CopyDataFromDevicveToFile(TNN_NS::BlobMap blob_map, std::string out_file, void* command_queue)
{
	if (blob_map.empty())
	{
		printf("CopyDataFromDevicveToFile Err,no output blob found\n");
		return;
	}

	//get output info
	TNN_NS::Blob* device_blob = blob_map.begin()->second;
	TNN_NS::BlobDesc blob_desc = device_blob->GetBlobDesc();
	const int data_count = TNN_NS::DimsVectorUtils::Count(blob_desc.dims);
	if (data_count <= 0)
	{
		printf("CopyDataFromDevicveToFile Err,output blob is empty\n");
		return;
	}

	//host buffer receiving the output; owned by the vector, so no manual free
	std::vector<float> output_data(static_cast<size_t>(data_count));
	TNN_NS::BlobHandle data_handle;
	data_handle.base = output_data.data();
	data_handle.bytes_offset = 0;

	//convert
	TNN_NS::Blob data_blob(blob_desc, data_handle);
	TNN_NS::CopyFromDevice(&data_blob, device_blob, command_queue);

	//write file
	FILE* fp = fopen(out_file.c_str(), "w");
	if (fp == NULL)
	{
		printf("CopyDataFromDevicveToFile Err,open output file failed: %s\n", out_file.c_str());
		return;
	}
	for (const float value : output_data)
	{
		fprintf(fp, "%f\n", value);
	}
	fclose(fp);
}

// End-to-end TNN example: loads a .tnnproto/.tnnmodel pair, creates an instance
// on the configured device, feeds the first input blob from a raw float file,
// runs Forward() and dumps the first output blob to a text file.
// Returns 0 on success, 1 if loading, initialisation or inference fails.
int main()
{
	const std::string model_name = "test.opt.tnnmodel";  // model weights (binary)
	const std::string proto_name = "test.opt.tnnproto";  // network description (text)
	const std::string input_file = "input.txt";
	const std::string output_file = "data.txt";

	TNN_NS::NetworkConfig myNet;
	TNN_NS::ModelConfig myModel;
	myModel.model_type = TNN_NS::MODEL_TYPE_TNN;
	myNet.device_type = TNN_NS::DEVICE_NAIVE;
	myNet.data_format = TNN_NS::DATA_FORMAT_NCHW;

	// params must be exactly {proto text, model content}, in that order.
	std::ifstream proto_stream(proto_name);
	if (!proto_stream.is_open() || !proto_stream.good()) {
		printf("read proto_file failed!\n");
		return 1;
	}
	myModel.params.push_back(
		std::string((std::istreambuf_iterator<char>(proto_stream)), std::istreambuf_iterator<char>()));

	// A missing model file is tolerated (empty weights, as before), but only one
	// entry is pushed so params never grows past two elements.
	std::string model_content;
	std::ifstream model_stream(model_name, std::ios::binary);
	if (model_stream.is_open() && model_stream.good()) {
		model_content.assign(std::istreambuf_iterator<char>(model_stream), std::istreambuf_iterator<char>());
	} else {
		printf("read model_file failed: %s, continuing with empty model content\n", model_name.c_str());
	}
	myModel.params.push_back(model_content);

	//Init
	TNN_NS::TNN net;
	TNN_NS::Status ret = net.Init(myModel);
	if (ret != TNN_NS::TNN_OK) {
		printf("TNN Init failed: %s\n", ret.description().c_str());
		return 1;
	}

	// CreateInst yields a null instance when the device is unavailable, e.g. when
	// its accelerator was not enabled in CMake (GetDevice returns NULL).
	auto instance = net.CreateInst(myNet, ret);
	if (!instance || ret != TNN_NS::TNN_OK) {
		printf("CreateInst failed: %s\n", ret.description().c_str());
		net.DeInit();
		return 1;
	}

	TNN_NS::BlobMap input_blob_maps;
	TNN_NS::BlobMap output_blob_maps;
	void* command_queue = nullptr;
	instance->GetAllInputBlobs(input_blob_maps);
	instance->GetAllOutputBlobs(output_blob_maps);
	instance->GetCommandQueue(&command_queue);

	CopyDataToDeviceFromFile(input_blob_maps, input_file, command_queue);

	ret = instance->Forward();
	if (ret != TNN_NS::TNN_OK) {
		printf("Forward failed: %s\n", ret.description().c_str());
		net.DeInit();
		return 1;
	}

	CopyDataFromDevicveToFile(output_blob_maps, output_file, command_queue);

	ret = net.DeInit();
	return ret == TNN_NS::TNN_OK ? 0 : 1;
}

 

本文地址:https://blog.csdn.net/azheng_wen/article/details/107489859

相关标签: TNN