KNN算法——实现手写数字识别(Sklearn实现)
程序员文章站
2024-03-08 08:39:57
...
KNN项目实战——手写数字识别
1、数据集介绍
需要识别的数字已经使用图形处理软件,处理成具有相同的色彩和大小:宽高是32像素x32像素的黑白图像。尽管采用本文格式存储图像不能有效地利用内存空间,但是为了方便理解,我们将图片转换为文本格式。
数字的文本格式如下:
数据集下载:
trainingDigits训练集下载 , testDigits测试集下载
这些文本格式存储的数字的文件命名也很有特点,格式为:数字的值_该数字的样本序号,如下:
2、准备数据:将图像转换为测试向量
将每个数字文件中32*32的二进制图像矩阵转换为1*1024的向量,作为一个样本输入。
3、代码实现
使用sklearn机器学习算法库中的KNN算法实现手写数字识别
import numpy as np
from os import listdir
from sklearn.neighbors import KNeighborsClassifier as KNN
"""
函数说明:将32x32的二进制图像转换为1x1024向量
"""
def img2vector(filename):
#创建1x1024零向量
returnVect = np.zeros((1, 1024))
#打开文件
fr = open(filename)
#按行读取
for i in range(32):
#读一行数据
lineStr = fr.readline()
#每一行的前32个元素依次添加到returnVect中
for j in range(32):
returnVect[0, 32*i+j] = int(lineStr[j])
#返回转换后的1x1024向量
return returnVect
"""
函数说明:手写数字分类测试
"""
def handwritingClassTest():
#训练集的Labels
hwLabels = []
#返回trainingDigits目录下的文件名
trainingFileList = listdir('trainingDigits')
#返回文件夹下文件的个数
m = len(trainingFileList)
#初始化训练的Mat矩阵,训练集
trainingMat = np.zeros((m, 1024))
#从文件名中解析出训练集的类别
for i in range(m):
#获得文件的名字
fileNameStr = trainingFileList[i]
#获得分类的数字
classNumber = int(fileNameStr.split('_')[0])
#将获得的类别添加到hwLabels中
hwLabels.append(classNumber)
#将每一个文件的1x1024数据存储到trainingMat矩阵中
trainingMat[i,:] = img2vector('trainingDigits/%s' % (fileNameStr))
#构建kNN分类器
neigh =KNN(n_neighbors = 3, algorithm = 'auto')
#拟合模型, trainingMat为训练矩阵,hwLabels为对应的标签
neigh.fit(trainingMat, hwLabels)
#返回testDigits目录下的文件列表
testFileList = listdir('testDigits')
#错误检测计数
errorCount = 0.0
#测试数据的数量
mTest = len(testFileList)
#从文件中解析出测试集的类别并进行 分类测试
for i in range(mTest):
#获得文件的名字
fileNameStr = testFileList[i]
#获得分类的数字
classNumber = int(fileNameStr.split('_')[0])
#获得测试集的1x1024向量,用于训练
vectorUnderTest = img2vector('testDigits/%s' % (fileNameStr))
#获得预测结果
classifierResult = neigh.predict(vectorUnderTest)
print("分类返回结果为%d\t真实结果为%d" % (classifierResult, classNumber))
if(classifierResult != classNumber):
errorCount += 1.0
print("总共错了%d个数据\n错误率为%f%%" % (errorCount, errorCount/mTest * 100))
"""
函数说明:main函数
"""
if __name__=='__main__':
handwritingClassTest()
结果输出:
上一篇: C语言位运算