机器学习框架ML.NET学习笔记【5】手写数字识别(续)
一、概述
上一篇文章我们利用ml.net的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断。思路很简单,就是写一个自定义的数据处理通道,输入为文件名,输出为float数字,里面保存的是像素信息。
样本包括6万张训练图片和1万张测试图片,图片为灰度图片,分辨率为20*20 。train_tags.tsv文件对每个图片的数值进行了标记,如下:
二、源码
全部代码:
namespace multiclassclassification_mnist { class program { //assets files download from:https://gitee.com/seabluescn/ml_assets static readonly string assetsfolder = @"d:\stepbystep\blogs\ml_assets\mnist"; static readonly string traintagspath = path.combine(assetsfolder, "train_tags.tsv"); static readonly string traindatafolder = path.combine(assetsfolder, "train"); static readonly string modelpath = path.combine(environment.currentdirectory, "data", "sdca-model.zip"); static void main(string[] args) { mlcontext mlcontext = new mlcontext(seed: 1); trainandsavemodel(mlcontext); testsomepredictions(mlcontext); console.writeline("hit any key to finish the app"); console.readkey(); } public static void trainandsavemodel(mlcontext mlcontext) { // step 1: 准备数据 var fulldata = mlcontext.data.loadfromtextfile<inputdata>(path: traintagspath, separatorchar: '\t', hasheader: false); var traintestdata = mlcontext.data.traintestsplit(fulldata, testfraction: 0.1); var traindata = traintestdata.trainset; var testdata = traintestdata.testset; // step 2: 配置数据处理管道 var dataprocesspipeline = mlcontext.transforms.custommapping(new loadimageconversion().getmapping(), contractname: "loadimageconversionaction") .append(mlcontext.transforms.conversion.mapvaluetokey("label", "number", keyordinality: valuetokeymappingestimator.keyordinality.byvalue)) .append(mlcontext.transforms.normalizemeanvariance( outputcolumnname: "featuresnormalizedbymeanvar", inputcolumnname: "imagepixels")); // step 3: 配置训练算法 (using a maximum entropy classification model trained with the l-bfgs method) var trainer = mlcontext.multiclassclassification.trainers.lbfgsmaximumentropy(labelcolumnname: "label", featurecolumnname: "featuresnormalizedbymeanvar"); var trainingpipeline = dataprocesspipeline.append(trainer) .append(mlcontext.transforms.conversion.mapkeytovalue("predictnumber", "label")); // step 4: 训练模型使其与数据集拟合 itransformer trainedmodel = trainingpipeline.fit(traindata); // step 5:评估模型的准确性 var predictions = trainedmodel.transform(testdata); var metrics = mlcontext.multiclassclassification.evaluate(data: predictions, labelcolumnname: "label", scorecolumnname: "score"); printmulticlassclassificationmetrics(trainer.tostring(), metrics); // step 6:保存模型 mlcontext.model.save(trainedmodel, traindata.schema, modelpath); } private static void testsomepredictions(mlcontext mlcontext) { // load model itransformer trainedmodel = mlcontext.model.load(modelpath, out var modelinputschema); // create prediction engine var predengine = mlcontext.model.createpredictionengine<inputdata, outputdata>(trainedmodel); directoryinfo testfolder = new directoryinfo(path.combine(assetsfolder, "test")); foreach(var image in testfolder.getfiles()) { count++; inputdata img = new inputdata() { filename = image.name }; var result = predengine.predict(img); console.writeline($"current source={img.filename},predictresult={result.getpredictresult()}"); } } } class inputdata { [loadcolumn(0)] public string filename; [loadcolumn(1)] public string number; [loadcolumn(1)] public float serial; } class outputdata : inputdata { public float[] score; public int getpredictresult() { float max = 0; int index = 0; for (int i = 0; i < score.length; i++) { if (score[i] > max) { max = score[i]; index = i; } } return index; } } }
三、分析
整个处理流程和上一篇文章基本一致,这里解释两个不一样的地方。
1、自定义的图片读取处理通道
namespace multiclassclassification_mnist { public class loadimageconversioninput { public string filename { get; set; } } public class loadimageconversionoutput { [vectortype(400)] public float[] imagepixels { get; set; } public string imagepath; } [custommappingfactoryattribute("loadimageconversionaction")] public class loadimageconversion : custommappingfactory<loadimageconversioninput, loadimageconversionoutput> { static readonly string traindatafolder = @"d:\stepbystep\blogs\ml_assets\mnist\train"; public void customaction(loadimageconversioninput input, loadimageconversionoutput output) { string imagepath = path.combine(traindatafolder, input.filename); output.imagepath = imagepath; bitmap bmp = image.fromfile(imagepath) as bitmap; output.imagepixels = new float[400]; for (int x = 0; x < 20; x++) for (int y = 0; y < 20; y++) { var pixel = bmp.getpixel(x, y); var gray = (pixel.r + pixel.g + pixel.b) / 3 / 16; output.imagepixels[x + y * 20] = gray; } bmp.dispose(); } public override action<loadimageconversioninput, loadimageconversionoutput> getmapping() => customaction; } }
这里可以看出,我们自定义的数据处理通道,输入为文件名称,输出是一个float数组,这里数组必须要指定宽度,由于图片分辨率为20*20,所以数组宽度指定为400,输出imagepath为文件详细地址,用来调试使用,没有实际用途。处理思路非常简单,遍历每个pixel,计算其灰度值,为了减少工作量我们把灰度值进行缩小,除以了16 ,由于后面数据会做归一化,所以这里影响不是太明显。
2、模型测试
directoryinfo testfolder = new directoryinfo(path.combine(assetsfolder, "test")); int count = 0; int success = 0; foreach(var image in testfolder.getfiles()) { count++; inputdata img = new inputdata() { filename = image.name }; var result = predengine.predict(img); if(int.parse(image.name.substring(0,1))==result.getpredictresult()) { success++; } }
我们把测试目录里的全面图片读出遍历了一遍,将其测试结果和实际结果做了一次验证,实际上是把评估(evaluate)的事情又重复做了一次,两次测试的成功率基本接近。
四、关于图片特征提取
我们是采用图片所有像素的灰度值来作为特征值的,但必须要强调的是:像素值矩阵不是图片的典型特征。虽然有时候对于较规则的图片,通过像素提取方式进行计算,也可以取得很好的效果,但在处理稍微复杂一点的图片的时候,就不管用了,原因很明显,我们人类在分析图片内容时看到的特征更多是线条等信息,绝对不是像素值,看下图:
我们人类很容易就判断出这两个图片表达的是同一件事情,但其像素值特征却相差甚远。
传统的图片特征提取方式很多,比如:sift、hog、lbp、haar等。 现在采用tensorflow的模型进行特征提取效果非常好。下一篇文章介绍图片分类时再进行详细介绍。
五、资源获取
源码下载地址:https://github.com/seabluescn/study_ml.net
工程名称:multiclassclassification_mnist_useful
mnist资源获取:https://gitee.com/seabluescn/ml_assets
上一篇: 人生苦短,很快到站