欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

机器学习框架ML.NET学习笔记【5】手写数字识别(续)

程序员文章站 2022-03-06 21:45:46
一、概述 上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断。思路很简单,就是写一个自定义的数据处理通道,输入为文件名,输出为float数字,里面保存的是像素信息。 样本包括6万 ......

一、概述

 上一篇文章我们利用ml.net的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断。思路很简单,就是写一个自定义的数据处理通道,输入为文件名,输出为float数字,里面保存的是像素信息。

 样本包括6万张训练图片和1万张测试图片,图片为灰度图片,分辨率为20*20 。train_tags.tsv文件对每个图片的数值进行了标记,如下:

机器学习框架ML.NET学习笔记【5】手写数字识别(续)

  

二、源码

 全部代码: 

namespace multiclassclassification_mnist
{
    class program
    {
        //assets files download from:https://gitee.com/seabluescn/ml_assets
        static readonly string assetsfolder = @"d:\stepbystep\blogs\ml_assets\mnist";
        static readonly string traintagspath = path.combine(assetsfolder, "train_tags.tsv");
        static readonly string traindatafolder = path.combine(assetsfolder, "train");
        static readonly string modelpath = path.combine(environment.currentdirectory, "data", "sdca-model.zip");

        static void main(string[] args)
        {
            mlcontext mlcontext = new mlcontext(seed: 1);
          
            trainandsavemodel(mlcontext);
            testsomepredictions(mlcontext);

            console.writeline("hit any key to finish the app");
            console.readkey();
        }

        public static void trainandsavemodel(mlcontext mlcontext)
        {
            // step 1: 准备数据
            var fulldata = mlcontext.data.loadfromtextfile<inputdata>(path: traintagspath, separatorchar: '\t', hasheader: false);
            var traintestdata = mlcontext.data.traintestsplit(fulldata, testfraction: 0.1);
            var traindata = traintestdata.trainset;
            var testdata = traintestdata.testset;

            // step 2: 配置数据处理管道        
            var dataprocesspipeline = mlcontext.transforms.custommapping(new loadimageconversion().getmapping(), contractname: "loadimageconversionaction")
               .append(mlcontext.transforms.conversion.mapvaluetokey("label", "number", keyordinality: valuetokeymappingestimator.keyordinality.byvalue))
               .append(mlcontext.transforms.normalizemeanvariance( outputcolumnname: "featuresnormalizedbymeanvar", inputcolumnname: "imagepixels"));


            // step 3: 配置训练算法 (using a maximum entropy classification model trained with the l-bfgs method)
            var trainer = mlcontext.multiclassclassification.trainers.lbfgsmaximumentropy(labelcolumnname: "label", featurecolumnname: "featuresnormalizedbymeanvar");
            var trainingpipeline = dataprocesspipeline.append(trainer)
                 .append(mlcontext.transforms.conversion.mapkeytovalue("predictnumber", "label"));


            // step 4: 训练模型使其与数据集拟合           
            itransformer trainedmodel = trainingpipeline.fit(traindata);          

            // step 5:评估模型的准确性           
            var predictions = trainedmodel.transform(testdata);
            var metrics = mlcontext.multiclassclassification.evaluate(data: predictions, labelcolumnname: "label", scorecolumnname: "score");
            printmulticlassclassificationmetrics(trainer.tostring(), metrics);
          
            // step 6:保存模型            
            mlcontext.model.save(trainedmodel, traindata.schema, modelpath);           
        }

        private static void testsomepredictions(mlcontext mlcontext)
        {
            // load model           
            itransformer trainedmodel = mlcontext.model.load(modelpath, out var modelinputschema);

            // create prediction engine 
            var predengine = mlcontext.model.createpredictionengine<inputdata, outputdata>(trainedmodel);
          
            directoryinfo testfolder = new directoryinfo(path.combine(assetsfolder, "test"));           
            foreach(var image in testfolder.getfiles())
            {
                count++;

                inputdata img = new inputdata()
                {
                    filename = image.name
                };
                var result = predengine.predict(img);
               
                console.writeline($"current source={img.filename},predictresult={result.getpredictresult()}");                
            }
        }       
    }

    class inputdata
    {
        [loadcolumn(0)]
        public string filename;

        [loadcolumn(1)]
        public string number;

        [loadcolumn(1)]
        public float serial;       
    }

    class outputdata : inputdata
    {
        public float[] score;
        public int getpredictresult()
        {
            float max = 0;
            int index = 0;
            for (int i = 0; i < score.length; i++)
            {
                if (score[i] > max)
                {
                    max = score[i];
                    index = i;
                }
            }
            return index;
        }       
    }   
}

  

三、分析

 整个处理流程和上一篇文章基本一致,这里解释两个不一样的地方。

1、自定义的图片读取处理通道

namespace multiclassclassification_mnist
{
    public class loadimageconversioninput
    {
        public string  filename { get; set; }
    }
 
    public class loadimageconversionoutput
    {
        [vectortype(400)]
        public float[] imagepixels { get; set; }
        public string imagepath;
    }

    [custommappingfactoryattribute("loadimageconversionaction")]
    public class loadimageconversion : custommappingfactory<loadimageconversioninput, loadimageconversionoutput>
    {       
        static readonly string traindatafolder = @"d:\stepbystep\blogs\ml_assets\mnist\train";

        public void customaction(loadimageconversioninput input, loadimageconversionoutput output)
        {  
            string imagepath = path.combine(traindatafolder, input.filename);
            output.imagepath = imagepath;

            bitmap bmp = image.fromfile(imagepath) as bitmap;           

            output.imagepixels = new float[400];
            for (int x = 0; x < 20; x++)
                for (int y = 0; y < 20; y++)
                {
                    var pixel = bmp.getpixel(x, y);
                    var gray = (pixel.r + pixel.g + pixel.b) / 3 / 16;
                    output.imagepixels[x + y * 20] = gray;
                }           
            bmp.dispose();                     
        }

        public override action<loadimageconversioninput, loadimageconversionoutput> getmapping()
              => customaction;
    }
}

 这里可以看出,我们自定义的数据处理通道,输入为文件名称,输出是一个float数组,这里数组必须要指定宽度,由于图片分辨率为20*20,所以数组宽度指定为400,输出imagepath为文件详细地址,用来调试使用,没有实际用途。处理思路非常简单,遍历每个pixel,计算其灰度值,为了减少工作量我们把灰度值进行缩小,除以了16 ,由于后面数据会做归一化,所以这里影响不是太明显。

 

2、模型测试

            directoryinfo testfolder = new directoryinfo(path.combine(assetsfolder, "test"));
            int count = 0;
            int success = 0;
            foreach(var image in testfolder.getfiles())
            {
                count++;

                inputdata img = new inputdata()
                {
                    filename = image.name
                };
                var result = predengine.predict(img);

                if(int.parse(image.name.substring(0,1))==result.getpredictresult())
                {
                    success++;
                }                
            }

 我们把测试目录里的全面图片读出遍历了一遍,将其测试结果和实际结果做了一次验证,实际上是把评估(evaluate)的事情又重复做了一次,两次测试的成功率基本接近。

 

四、关于图片特征提取

我们是采用图片所有像素的灰度值来作为特征值的,但必须要强调的是:像素值矩阵不是图片的典型特征。虽然有时候对于较规则的图片,通过像素提取方式进行计算,也可以取得很好的效果,但在处理稍微复杂一点的图片的时候,就不管用了,原因很明显,我们人类在分析图片内容时看到的特征更多是线条等信息,绝对不是像素值,看下图:

机器学习框架ML.NET学习笔记【5】手写数字识别(续)

我们人类很容易就判断出这两个图片表达的是同一件事情,但其像素值特征却相差甚远。

 传统的图片特征提取方式很多,比如:sift、hog、lbp、haar等。 现在采用tensorflow的模型进行特征提取效果非常好。下一篇文章介绍图片分类时再进行详细介绍。 

 

五、资源获取

源码下载地址:https://github.com/seabluescn/study_ml.net

工程名称:multiclassclassification_mnist_useful

mnist资源获取:https://gitee.com/seabluescn/ml_assets

点击查看机器学习框架ml.net学习笔记系列文章目录