欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

机器学习框架ML.NET学习笔记【4】多元分类之手写数字识别

程序员文章站 2022-03-20 10:13:44
一、问题与解决方案 通过多元分类算法进行手写数字识别,手写数字的图片分辨率为8*8的灰度图片、已经预先进行过处理,读取了各像素点的灰度值,并进行了标记。 其中第0列是序号(不参与运算)、1-64列是像素值、65列是结果。 我们以64位像素值为特征进行多元分类,算法采用SDCA最大熵分类算法。 二、源 ......

一、问题与解决方案

通过多元分类算法进行手写数字识别,手写数字的图片分辨率为8*8的灰度图片、已经预先进行过处理,读取了各像素点的灰度值,并进行了标记。

机器学习框架ML.NET学习笔记【4】多元分类之手写数字识别

其中第0列是序号(不参与运算)、1-64列是像素值、65列是结果。

我们以64位像素值为特征进行多元分类,算法采用sdca最大熵分类算法。

 

二、源码

 先贴出全部代码:

namespace multiclassclassification_mnist
{
    class program
    {
        static readonly string traindatapath = path.combine(environment.currentdirectory, "data", "optdigits-full.csv");
        static readonly string modelpath = path.combine(environment.currentdirectory, "data", "sdca-model.zip");

        static void main(string[] args)
        {
            mlcontext mlcontext = new mlcontext(seed: 1);
          
            trainandsavemodel(mlcontext);
            testsomepredictions(mlcontext);

            console.writeline("hit any key to finish the app");
            console.readkey();
        }
              

        public static void trainandsavemodel(mlcontext mlcontext)
        {
            // step 1: 准备数据
            var fulldata = mlcontext.data.loadfromtextfile(path: traindatapath,
                    columns: new[]
                    {
                        new textloader.column("serial", datakind.single, 0),
                        new textloader.column("pixelvalues", datakind.single, 1, 64),
                        new textloader.column("number", datakind.single, 65)
                    },
                    hasheader: true,
                    separatorchar: ','
                    );

            var traintestdata = mlcontext.data.traintestsplit(fulldata, testfraction: 0.2);
            var traindata = traintestdata.trainset;
            var testdata = traintestdata.testset;

            // step 2: 配置数据处理管道        
            var dataprocesspipeline = mlcontext.transforms.conversion.mapvaluetokey("label", "number", keyordinality: valuetokeymappingestimator.keyordinality.byvalue);

            // step 3: 配置训练算法
            var trainer = mlcontext.multiclassclassification.trainers.sdcamaximumentropy(labelcolumnname: "label", featurecolumnname: "pixelvalues");
            var trainingpipeline = dataprocesspipeline.append(trainer)
              .append(mlcontext.transforms.conversion.mapkeytovalue("number", "label"));
            
            // step 4: 训练模型使其与数据集拟合
            console.writeline("=============== train the model fitting to the dataset ===============");           

            itransformer trainedmodel = trainingpipeline.fit(traindata);         


            // step 5:评估模型的准确性
            console.writeline("===== evaluating model's accuracy with test data =====");
            var predictions = trainedmodel.transform(testdata);
            var metrics = mlcontext.multiclassclassification.evaluate(data: predictions, labelcolumnname: "number", scorecolumnname: "score");
            printmulticlassclassificationmetrics(trainer.tostring(), metrics);
         
            // step 6:保存模型              
            mlcontext.componentcatalog.registerassembly(typeof(debugconversion).assembly);
            mlcontext.model.save(trainedmodel, traindata.schema, modelpath);
            console.writeline("the model is saved to {0}", modelpath);
        }

        private static void testsomepredictions(mlcontext mlcontext)
        {
            // load model           
            itransformer trainedmodel = mlcontext.model.load(modelpath, out var modelinputschema);

            // create prediction engine 
            var predengine = mlcontext.model.createpredictionengine<inputdata, outputdata>(trainedmodel);

            //num 1
            inputdata mnist1 = new inputdata()
            {               
                pixelvalues = new float[] { 0, 0, 0, 0, 14, 13, 1, 0, 0, 0, 0, 5, 16, 16, 2, 0, 0, 0, 0, 14, 16, 12, 0, 0, 0, 1, 10, 16, 16, 12, 0, 0, 0, 3, 12, 14, 16, 9, 0, 0, 0, 0, 0, 5, 16, 15, 0, 0, 0, 0, 0, 4, 16, 14, 0, 0, 0, 0, 0, 1, 13, 16, 1, 0 }
            }; 
            var resultprediction1 = predengine.predict(mnist1);
            resultprediction1.printtoconsole();           
        }      
    }

    class inputdata
    {
        public float serial;
        [vectortype(64)]
        public float[] pixelvalues;               
        public float number;       
    }

    class outputdata : inputdata
    {  
        public float[] score;  
    }   
}

  

三、分析

 整体流程和二元分类没有什么区别,下面解释一下有差异的两个地方。

 1、加载数据

      // step 1: 准备数据
            var fulldata = mlcontext.data.loadfromtextfile(path: traindatapath,
                    columns: new[]
                    {
                        new textloader.column("serial", datakind.single, 0),
                        new textloader.column("pixelvalues", datakind.single, 1, 64),
                        new textloader.column("number", datakind.single, 65)
                    },
                    hasheader: true,
                    separatorchar: ','
                    );

  这次我们不是通过实体对象来加载数据,而是通过列信息来进行加载,其中pixelvalues是特征值,number是标签值。

 

2、训练通道

            // step 2: 配置数据处理管道        
            var dataprocesspipeline = mlcontext.transforms.conversion.mapvaluetokey("label", "number", keyordinality: valuetokeymappingestimator.keyordinality.byvalue)

// step 3: 配置训练算法 var trainer = mlcontext.multiclassclassification.trainers.sdcamaximumentropy(labelcolumnname: "label", featurecolumnname: "pixelvalues");
var trainingpipeline = dataprocesspipeline.append(trainer)
.append(mlcontext.transforms.conversion.mapkeytovalue("number", "label"));

// step 4: 训练模型使其与数据集拟合
itransformer trainedmodel = trainingpipeline.fit(traindata);

 首先通过mapvaluetokey方法将number值转换为key类型,多元分类算法要求标签值必须是这种类型(类似枚举类型,二元分类要求标签为bool类型)。关于这个转换的原因及编码方式,下面详细介绍。

 

四、键值类型编码与独热编码

 mapvaluetokey功能是将(字符串)值类型转换为keytpye类型。

有时候某些输入字段用来表示类型(类别特征),但本身并没有特别的含义,比如编号、电话号码、行政区域名称或编码等,这里需要把这些类型转换为1到一个整数如1-300来进行重新编号。

举个简单的例子,我们进行图片识别的时候,目标结果可能是“猫咪”、“小狗”、“人物”这些分类,需要把这些分类转换为1、2、3这样的整数。但本文的标签值本身就是1、2、3,为什么还要转换呢?因为我们这里的一二三其实不是数学意义上的数字,而是一种标志,可以理解为壹、贰、叁,所以要进行编码。

 mapkeytovalue和mapvaluetokey相反,它把将键类型转换回其原始值(字符串)。就是说标签是文本格式,在运算前已经被转换为数字枚举类型了,此时预测结果为数字,通过mapkeytovalue将其结果转换为对应文本。

mapvaluetokey一般是对标签值进行编码,一般不用于特征值,如果是特征值为字符串类型的,建议采用独热编码。独热编码即 one-hot 编码,又称一位有效编码,其方法是使用n位状态寄存器来对n个状态进行编码,每个状态都由他独立的寄存器位,并且在任意时候,其中只有一位有效。例如:

自然状态码为:0,1,2,3,4,5
独热编码为:000001,000010,000100,001000,010000,100000

怎么理解这个事情呢?举个例子,假如我们要进行人的身材的分析,但我们希望加入地域特征,比如:“黑龙江”、“山东”、“湖南”、“广东”这种特征,但这种字符串机器学习是不认识的,必须转换为浮点数,刚才提到mapkeytovalue可以把字符串转换为数字,为什么这里要采用独热编码呢?简单来说,假设把地域名称转换为1到10几个数字,在欧氏几何中1到3的欧拉距离和1到9的欧拉距离是不等的,但经过独热编码后,任意两点间的欧拉距离都是相等的,而我们这里的地域特征仅仅是想表达分类关系,彼此之间没有其他逻辑关系,所以应该采用独热编码。
机器学习框架ML.NET学习笔记【4】多元分类之手写数字识别

 

五、进度调试

一般机器算法的数据拟合过程时间都比较长,有时程序跑了两个小时还没结束,也不知道还需要多长时间,着实让人着急,所以及时了解学习进度,是很有必要的。

由于机器学习算法一般都有“递归直到收敛”这种操作,所以我们是没有办法预先知道最终运算次数的,能做到的只能打印一些过程信息,看到程序在动,心里也有点底,当系统跑过一次之后,基本就大致知道需要多少次拟合了,后面再调试就可以大致了解进度了。补充一句,可不可以在测试阶段先减少样本数据进行快速调试,调试通过后再切换到全样本进行训练?其实不行,有时候样本数量小,可能会引起指标震荡,时间反而长了。

之前在githube上看到有人通过mlcontext.log事件来打印调试信息,我试了一下,发现没法控制筛选内容,不太方便,后来想到一个方法,就是新增一个自定义数据处理通道,这个通道不做具体事情,就打印调试信息。

类定义:

namespace multiclassclassification_mnist
{
    public class debugconversioninput
    {
        public float serial { get; set; }
    }
 
    public class debugconversionoutput
    {
        public float debugfeature { get; set; }
    }

    [custommappingfactoryattribute("debugconversionaction")]
    public class debugconversion : custommappingfactory<debugconversioninput, debugconversionoutput>
    {       static long totalcount = 0;

        public void customaction(debugconversioninput input, debugconversionoutput output)
        {
            output.debugfeature = 1.0f;  
totalcount++; console.writeline($"debugconversion.customaction's debug info.totalcount={totalcount} "); } public override action<debugconversioninput, debugconversionoutput> getmapping() => customaction; } }

 使用方法:

 var dataprocesspipeline = mlcontext.transforms.custommapping(new debugconversion().getmapping(), contractname: "debugconversionaction")
       .append(...)
       .append(mlcontext.transforms.concatenate("features", new string[] { "realfeatures", "debugfeature" }));

 通过custommapping加载我们自定义的数据处理通道,由于数据集是懒加载(lazy)的,所以必须把我们自定义数据处理通道的输出加入为特征值,才能参与运算,然后算法在操作每一条数据时都会调用到customaction方法,这样就可以打印进度信息了。为了不影响运算结果,我们把这个数据处理通道的输出值固定为1.0f 。

 

六、资源获取

源码下载地址:https://github.com/seabluescn/study_ml.net

工程名称:multiclassclassification_mnist

点击查看机器学习框架ml.net学习笔记系列文章目录

上一篇: *的植物

下一篇: 刚拿到驾照