欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

机器学习框架ML.NET学习笔记【2】入门之二元分类

程序员文章站 2022-04-21 19:57:32
一、准备样本 接上一篇文章提到的问题:根据一个人的身高、体重来判断一个人的身材是否很好。但我手上没有样本数据,只能伪造一批数据了,伪造的数据比较标准,用来学习还是蛮合适的。 下面是我用来伪造数据的代码: string Filename = "./figure_full.csv"; StreamWri ......

一、准备样本

接上一篇文章提到的问题:根据一个人的身高、体重来判断一个人的身材是否很好。但我手上没有样本数据,只能伪造一批数据了,伪造的数据比较标准,用来学习还是蛮合适的。

下面是我用来伪造数据的代码:

           string filename = "./figure_full.csv";
            streamwriter sw = new streamwriter(filename, false);
            sw.writeline("height,weight,result");

            random random = new random();
            float height, weight;
            result result;

            for (int i = 0; i < 2000; i++)
            {
                height = random.next(150, 195);
                weight = random.next(70, 200);

                if (height > 170 && weight < 120)
                    result = result.good;
                else
                    result = result.bad;
               
                sw.writeline($"{height},{weight},{(int)result}");
            }


   enum result
    {
        bad=0,
        good=1
    }

制造成功后的数据如下:

机器学习框架ML.NET学习笔记【2】入门之二元分类

 用记事本打开:

机器学习框架ML.NET学习笔记【2】入门之二元分类

 

二、源码

数据准备好了,我们就用准备好的数据进行学习了,先贴出全部代码,然后再逐一解释:

namespace binaryclassification_figure
{
    class program
    {
        static readonly string datapath = path.combine(environment.currentdirectory, "data", "figure_full.csv");
        static readonly string modelpath = path.combine(environment.currentdirectory, "data", "fasttree_model.zip");

        static void main(string[] args)
        {
            trainandsave();
            loadandprediction();

            console.writeline("press any to exit!");
            console.readkey();
        }

        static void trainandsave()
        {
            mlcontext mlcontext = new mlcontext();           

            //准备数据
            var fulldata = mlcontext.data.loadfromtextfile<figuredata>(path: datapath, hasheader: true, separatorchar: ',');           
            var traintestdata = mlcontext.data.traintestsplit(fulldata,testfraction:0.2);
            var traindata = traintestdata.trainset;
            var testdata = traintestdata.testset;

            //训练 
            iestimator<itransformer> dataprocesspipeline = mlcontext.transforms.concatenate("features", new[] { "height", "weight" })
                .append(mlcontext.transforms.normalizemeanvariance(inputcolumnname: "features", outputcolumnname: "featuresnormalizedbymeanvar"));
            iestimator<itransformer> trainer = mlcontext.binaryclassification.trainers.fasttree(labelcolumnname: "result", featurecolumnname: "featuresnormalizedbymeanvar");
            iestimator<itransformer> trainingpipeline = dataprocesspipeline.append(trainer); 
            itransformer model = trainingpipeline.fit(traindata);

            //评估
            var predictions = model.transform(testdata);
            var metrics = mlcontext.binaryclassification.evaluate(data: predictions, labelcolumnname: "result", scorecolumnname: "score");
            printbinaryclassificationmetrics(trainer.tostring(), metrics);

            //保存模型
            mlcontext.model.save(model, traindata.schema, modelpath);
            console.writeline($"model file saved to :{modelpath}");
        }      

        static void loadandprediction()
        {
            var mlcontext = new mlcontext();
            itransformer model = mlcontext.model.load(modelpath, out var inputschema);
            var predictionengine = mlcontext.model.createpredictionengine<figuredata, figuredatepredicted>(model);

            figuredata test = new figuredata();
            test.weight = 115;
            test.height = 171;

            var prediction = predictionengine.predict(test);
            console.writeline($"predict result :{prediction.predictedlabel}");
        }      
    }

    public class figuredata
    {
        [loadcolumn(0)]
        public float height { get; set; }

        [loadcolumn(1)]
        public float weight { get; set; }

        [loadcolumn(2)]
        public bool result { get; set; }       
    }

    public class figuredatepredicted : figuredata
    {
        public bool predictedlabel;
    }
}

 

三、对代码的解释

1、读取样本数据

        string datapath = path.combine(environment.currentdirectory, "data", "figure_full.csv");
        mlcontext mlcontext = new mlcontext();          

            //准备数据
            var fulldata = mlcontext.data.loadfromtextfile<figuredata>(path: datapath, hasheader: true, separatorchar: ',');           
            var traintestdata = mlcontext.data.traintestsplit(fulldata,testfraction:0.2);
            var traindata = traintestdata.trainset;
            var testdata = traintestdata.testset;    

 loadfromtextfile<figuredata>(path: datapath, hasheader: true, separatorchar: ',')用来读取数据到dataview

figuredata类是和样本数据对应的实体类,loadcolumn特性指示该属性对应该条数据中的第几个数据。

    public class figuredata
    {
        [loadcolumn(0)]
        public float height { get; set; }

        [loadcolumn(1)]
        public float weight { get; set; }

        [loadcolumn(2)]
        public bool result { get; set; }       
    }

 path:文件路径

hasheader:文本文件是否包含标题

separatorchar:用来分割数据的字符,我们用的是逗号,常用的还有跳格符‘\t’

traintestsplit(fulldata,testfraction:0.2)用来随机分割数据,分成学习数据和评估用的数据,通常情况,如果数据较多,测试数据取20%左右比较合适,如果数据量较少,测试数据取10%左右比较合适。

如果不通过分割,准备两个数据文件,一个用来训练、一个用来评估,效果是一样的。

 

2、训练 

            //训练 
            iestimator<itransformer> dataprocesspipeline = mlcontext.transforms.concatenate("features", new[] { "height", "weight" })
                .append(mlcontext.transforms.normalizemeanvariance(inputcolumnname: "features", outputcolumnname: "featuresnormalizedbymeanvar"));
            iestimator<itransformer> trainer = mlcontext.binaryclassification.trainers.fasttree(labelcolumnname: "result", featurecolumnname: "featuresnormalizedbymeanvar");
            iestimator<itransformer> trainingpipeline = dataprocesspipeline.append(trainer); 
            itransformer model = trainingpipeline.fit(traindata);

  idataview这个数据集就类似一个表格,它的列(column)是可以动态增加的,一开始我们通过loadfromtextfile获得的数据集包括:height、weight、result这几个列,在进行训练之前,我们还要对这个数据集进行处理,形成符合我们要求的数据集。

concatenate这个方法是把多个列,组合成一个列,因为二元分类的机器学习算法只接收一个特征列,所以要把多个特征列(height、weight)组合成一个特征列features(组合的结果应该是个float数组)。

normalizemeanvariance是对列进行归一化处理,这里输入列为:features,输出列为:featuresnormalizedbymeanvar,归一化的含义见本文最后一节介绍。

数据集就绪以后,就要选择学习算法,针对二元分类,我们选择了快速决策树算法fasttree,我们需要告诉这个算法特征值放在哪个列里面(featuresnormalizedbymeanvar),标签值放在哪个列里面(result)。

链接数据处理管道和算法形成学习管道,将数据集中的数据逐一通过学习管道进行学习,形成机器学习模型。

有了这个模型我们就可以通过它进行实际应用了。但我们一般不会现在就使用这个模型,我们需要先评估一下这个模型,然后把模型保存下来。以后应用时再通过文件读取出模型,然后进行应用,这样就不用等待学习的时间了,通常学习的时间都比较长。

 

3、评估 

            //评估
            var predictions = model.transform(testdata);
            var metrics = mlcontext.binaryclassification.evaluate(data: predictions, labelcolumnname: "result");
            printbinaryclassificationmetrics(trainer.tostring(), metrics);

  评估的过程就是对测试数据集进行批量转换(transform),转换过的数据集会多出一个“predictedlabel”的列,这个就是模型评估的结果,逐条将这个结果和实际结果(result)进行比较,就最终形成了效果评估数据。

我们可以打印这个评估结果,查看其成功率,一般成功率大于97%就是比较好的模型了。由于我们伪造的数据比较整齐,所以我们这次评估的成功率为100%。

注意:评估过程不会提升现有的模型能力,只是对现有模型的一种检测。

 

4、保存模型 

//保存模型
           string modelpath = path.combine(environment.currentdirectory, "data", "fasttree_model.zip");
            mlcontext.model.save(model, traindata.schema, modelpath);
            console.writeline($"model file saved to :{modelpath}");

 这个没啥好解释的。

 

5、读取模型并创建预测引擎 

           //读取模型
            var mlcontext = new mlcontext();
            itransformer model = mlcontext.model.load(modelpath, out var inputschema);

            //创建预测引擎
            var predictionengine = mlcontext.model.createpredictionengine<figuredata, figuredatepredicted>(model);

 创建预测引擎的功能和transform是类似的,不过transform是处理批量记录,这里只处理一条数据,而且这里的输入输出是实体对象,定义如下:

   public class figuredata
    {
        [loadcolumn(0)]
        public float height { get; set; }

        [loadcolumn(1)]
        public float weight { get; set; }

        [loadcolumn(2)]
        public bool result { get; set; }       
    }

    public class figuredatepredicted : figuredata
    {
        public bool predictedlabel;
    }

 由于预测结果里放在“predictedlabel”字段中,所以figuredatepredicted类必须要包含predictedlabel属性,目前figuredatepredicted 类是从figuredata类继承的,由于我们只用到predictedlabel属性,所以不继承也没有关系,如果继承的话,后面要调试的话会方便一点。

 

6、应用 

            figuredata test = new figuredata
            {
                weight = 115,
                height = 171
            };

            var prediction = predictionengine.predict(test);
            console.writeline($"predict result :{prediction.predictedlabel}");

 这部分代码就比较简单,test是我们要预测的对象,预测后打印出预测结果。

 

四、附:数据归一化

 机器学习的算法中一般会有很多的乘法运算,当运算的数字过大时,很容易在多次运算后溢出,为了防止这种情况,就要对数据进行归一化处理。归一化的目标就是把参与运算的特征数变为(0,1)或(-1,1)之间的浮点数,常见的处理方式有:min-max标准化、log函数转换、对数函数转换等。

机器学习框架ML.NET学习笔记【2】入门之二元分类

我们这次采用的是平均方差归一化方法。

 

五、资源

源码下载地址:https://github.com/seabluescn/study_ml.net

工程名称:binaryclassification_figure