欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

机器学习框架ML.NET学习笔记【8】目标检测(采用YOLO2模型)

程序员文章站 2022-03-06 21:17:10
一、概述 本篇文章介绍通过YOLO模型进行目标识别的应用,原始代码来源于:https://github.com/dotnet/machinelearning-samples 实现的功能是输入一张图片,对图片中的目标进行识别,输出结果在图片中通过红色框线标记出来。如下: YOLO简介 YOLO(You ......

一、概述

本篇文章介绍通过yolo模型进行目标识别的应用,原始代码来源于:https://github.com/dotnet/machinelearning-samples

实现的功能是输入一张图片,对图片中的目标进行识别,输出结果在图片中通过红色框线标记出来。如下:

机器学习框架ML.NET学习笔记【8】目标检测(采用YOLO2模型)

 yolo简介

 yolo(you only look once)是一种最先进的实时目标检测系统。官方网站:https://pjreddie.com/darknet/yolo/

 本文采用的是tinyyolo2模型,可以识别的目标类型包括:"aeroplane", "bicycle", "bird", "boat", "bottle","bus", "car", "cat", "chair", "cow","diningtable", "dog", "horse", "motorbike", "person","pottedplant", "sheep", "sofa", "train", "tvmonitor" 。 

onnx简介

onnx 即open neural network exchange(开放神经网络交换格式),是一个用于表示深度学习模型的通用标准,可使模型在不同框架之间进行互相访问,其规范及代码主要由微软,亚马逊 ,facebook 和 ibm 等公司共同制定与开发。有了onnx标准,我们就可以在ml.net代码中使用通过其他机器学习框架训练并保存的模型。

 

二、代码分析

 1、main方法

        static void main(string[] args)
        {
            trainandsave();
            loadandpredict();

            console.writeline("press any key to exit!");
            console.readkey();
        }

 第一次运行时需要运行trainandsave方法,生成本地模型后,可以直接运行生产代码。

 

2、训练并保存模型

     static readonly string tagstsv = path.combine(trainimagesfolder,  "tags.tsv");       
     private static void trainandsave() { var mlcontext = new mlcontext(); var traindata = mlcontext.data.loadfromtextfile<imagenetdata>(tagstsv); var pipeline = mlcontext.transforms.loadimages(outputcolumnname: "image", imagefolder: trainimagesfolder, inputcolumnname: nameof(imagenetdata.imagepath)) .append(mlcontext.transforms.resizeimages(outputcolumnname: "image", imagewidth: imagenetsettings.imagewidth, imageheight: imagenetsettings.imageheight, inputcolumnname: "image")) .append(mlcontext.transforms.extractpixels(outputcolumnname: "image")) .append(mlcontext.transforms.applyonnxmodel(modelfile: yolo_modelfilepath, outputcolumnnames: new[] { tinyyolomodelsettings.modeloutput }, inputcolumnnames: new[] { tinyyolomodelsettings.modelinput })); var model = pipeline.fit(traindata); using (var file = file.openwrite(objectdetectionmodelfilepath)) mlcontext.model.save(model, traindata.schema, file); console.writeline("save model success!"); }

  imagenetdata类定义如下:

    public class imagenetdata
    {
        [loadcolumn(0)]
        public string imagepath;

        [loadcolumn(1)]
        public string label;
    }

tags.tsv文件中仅包含一条样本数据,因为模型已经训练好,不存在再次训练的意义。这里只要放一张图片样本即可,通过fit方法建立数据处理通道模型。

applyonnxmodel方法加载第三方onnx模型,

    public struct tinyyolomodelsettings
    {
        // input tensor name
        public const string modelinput = "image";

        // output tensor name
        public const string modeloutput = "grid";
    }

 其中,输入、输出的列名称是指定的。可以通过安装netron这样的工具来查询onnx文件的详细信息,可以看到输入输出的数据列名称。


3、应用
        private static void loadandpredict()
        {
            var mlcontext = new mlcontext();

            itransformer trainedmodel;
            using (var stream = file.openread(objectdetectionmodelfilepath))
            {
                trainedmodel = mlcontext.model.load(stream, out var modelinputschema);               
            }
            var predictionengine = mlcontext.model.createpredictionengine<imagenetdata, imagenetprediction>(trainedmodel);

            directoryinfo testdir = new directoryinfo(testimagesfolder);
            foreach (var jpgfile in testdir.getfiles("*.jpg"))
            {  
                imagenetdata image = new imagenetdata
                {
                    imagepath = jpgfile.fullname
                };               
var predicted = predictionengine.predict(image); predictimage(image.imagepath, predicted); } }
代码遍历一个文件夹下面的jpg文件。对每一个文件进行转换,获得预测结果。
imagenetprediction类定义如下:
    public class imagenetprediction
    {
        [columnname(tinyyolomodelsettings.modeloutput)]
        public float[] predictedlabels;       
    }

 输出的“grid”列数据是一个float数组,不能直接理解其含义,所以需要通过代码将其数据转换为便于理解的格式。

     yolowinmlparser _parser = new yolowinmlparser();
     ilist<yoloboundingbox> boundingboxes = _parser.parseoutputs(predicted.predictedlabels, 0.4f);            

yolowinmlparser.parseoutputs方法将float数组转为yoloboundingbox对象的列表,第二个参数是可信度阙值,只输出大于该可信度的数据。

yoloboundingbox类定义如下:

    class yoloboundingbox
    {    
        public string label { get; set; }
        public float confidence { get; set; }

        public float x { get; set; }
        public float y { get; set; }
        public float height { get; set; }
        public float width { get; set; }
        public rectanglef rect
        {
            get { return new rectanglef(x, y, width, height); }
        }
    }

 其中:label为目标类型,confidence为可行程度。

由于yolo的特点导致对同一个目标会输出多个同样的检测结果,所以还需要对检测结果进行过滤,去掉那些高度重合的结果。

     yolowinmlparser _parser = new yolowinmlparser();
     ilist<yoloboundingbox> boundingboxes = _parser.parseoutputs(predicted.predictedlabels, 0.4f); 
     var filteredboxes = _parser.nonmaxsuppress(boundingboxes, 5, 0.6f);

 yolowinmlparser.nonmaxsuppress第二个参数表示最多保留多少个结果,第三个参数表示重合率阙值,将去掉重合率大于该值的记录。

 

四、资源获取 

源码下载地址:https://github.com/seabluescn/study_ml.net

工程名称:yolo_objectdetection

资源获取:https://gitee.com/seabluescn/ml_assets (objectdetection)

点击查看机器学习框架ml.net学习笔记系列文章目录