AI入门算法之KNN(手写数字识别)
程序员文章站
2024-03-08 08:14:15
...
训练集和测试集来自原文给出的下载链接(原文此处为“点击打开链接”);下面的代码从当前工作目录下的 trainingDigits 和 testDigits 两个目录中分别读取训练样本和测试样本。
数字样本的处理方式是:将每个 32×32 的 0/1 矩阵按行展开,转换成 1×1024 的 NumPy 数组。
例如其中的一个分类标签为0的32*32矩阵可表示为:
00000000000001111000000000000000
00000000000011111110000000000000
00000000001111111111000000000000
00000001111111111111100000000000
00000001111111011111100000000000
00000011111110000011110000000000
00000011111110000000111000000000
00000011111110000000111100000000
00000011111110000000011100000000
00000011111110000000011100000000
00000011111100000000011110000000
00000011111100000000001110000000
00000011111100000000001110000000
00000001111110000000000111000000
00000001111110000000000111000000
00000001111110000000000111000000
00000001111110000000000111000000
00000011111110000000001111000000
00000011110110000000001111000000
00000011110000000000011110000000
00000001111000000000001111000000
00000001111000000000011111000000
00000001111000000000111110000000
00000001111000000001111100000000
00000000111000000111111000000000
00000000111100011111110000000000
00000000111111111111110000000000
00000000011111111111110000000000
00000000011111111111100000000000
00000000001111111110000000000000
00000000000111110000000000000000
import operator
from numpy import *
from audioop import reverse
from nt import listdir
def classify0(inX, dataSet, labels, k):
    """Classify one sample by majority vote among its k nearest neighbours.

    Distances are Euclidean, computed between ``inX`` and every row of
    ``dataSet``.

    Args:
        inX: Feature vector to classify. Either a flat sequence or a
            single-row 2-D array with as many features as ``dataSet`` has
            columns.
        dataSet: 2-D array-like of training samples, one sample per row.
        labels: Sequence of class labels; ``labels[i]`` is the label of
            row ``i`` of ``dataSet``.
        k: Number of nearest neighbours that take part in the vote.

    Returns:
        The label that received the most votes. When several labels tie,
        the one encountered first among the nearest neighbours wins.

    Raises:
        ValueError: If ``k`` is smaller than 1 or larger than the number
            of training samples.
    """
    # Imported locally so the function does not depend on module-level
    # imports beyond the numpy names it already used.
    from collections import Counter

    # Coerce to float arrays: plain lists become usable and unsigned
    # integer inputs cannot wrap around during the subtraction.
    dataSet = asarray(dataSet, dtype=float)
    dataSetSize = dataSet.shape[0]
    if not 1 <= k <= dataSetSize:
        raise ValueError(
            'k must be between 1 and the number of training samples '
            '(%d), got %r' % (dataSetSize, k)
        )

    # Broadcasting subtracts inX from every row without materialising a
    # tiled copy of it.
    diffMat = asarray(inX, dtype=float) - dataSet
    distance = (diffMat ** 2).sum(axis=1) ** 0.5
    sortedDistIndicies = distance.argsort()

    votes = Counter(labels[i] for i in sortedDistIndicies[:k])
    return votes.most_common(1)[0][0]
def img2vector(filename):
    """Read a 32x32 text image of digits and flatten it into a 1x1024 array.

    The first 32 lines of the file are read, and the first 32 characters
    of each line are converted with ``int``. Character ``j`` of line ``i``
    is stored at column ``i * 32 + j`` (row-major order).

    Args:
        filename: Path of the text file to read.

    Returns:
        A numpy array of shape ``(1, 1024)`` holding the parsed values.

    Raises:
        IndexError: If one of the first 32 lines has fewer than 32 characters.
        ValueError: If a character cannot be converted with ``int``.
    """
    returnVector = zeros((1, 1024))
    # The context manager closes the file even when parsing fails; the
    # previous version never closed the handle.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVector[0, i * 32 + j] = int(lineStr[j])
    return returnVector
def _loadDigitDir(dirName):
    """Load every digit file in ``dirName`` into a matrix plus its labels.

    The class label is the integer before the first ``_`` in the file
    name, after dropping everything from the first ``.`` onwards.

    Returns:
        A tuple ``(dataMat, labels)``: ``dataMat`` has one 1024-value row
        per file and ``labels[i]`` is the label parsed for row ``i``.
    """
    # Local import: os is portable, unlike a Windows-only listdir.
    import os

    fileNames = os.listdir(dirName)
    labels = []
    dataMat = zeros((len(fileNames), 1024))
    for i, fileName in enumerate(fileNames):
        labels.append(int(fileName.split('.')[0].split('_')[0]))
        # os.path.join picks the right separator on every platform; the
        # old code hard-coded a backslash for the training directory.
        dataMat[i, :] = img2vector(os.path.join(dirName, fileName))
    return dataMat, labels


def handWriting():
    """Evaluate the 3-nearest-neighbour digit classifier and print its error rate.

    Training samples are read from the ``trainingDigits`` directory and
    test samples from the ``testDigits`` directory, both relative to the
    current working directory. Every test sample is classified with
    ``classify0`` using ``k=3``, and the fraction of misclassified samples
    is printed.

    Returns:
        None. The result is reported on standard output only.
    """
    trainMat, hwLabels = _loadDigitDir('trainingDigits')
    testMat, testLabels = _loadDigitDir('testDigits')

    errorCount = 0.0
    mtest = len(testLabels)
    for i in range(mtest):
        classifierResult = classify0(testMat[i], trainMat, hwLabels, 3)
        if classifierResult != testLabels[i]:
            errorCount += 1.0

    # With no test files there is no meaningful rate; report nan instead
    # of failing with ZeroDivisionError.
    errorRate = errorCount / float(mtest) if mtest else float('nan')
    print('the total error rate is: %f' % errorRate)
# Run the full train/test evaluation as soon as this module is loaded.
handWriting()
最后得到的测试集正确率接近99%(即程序输出的错误率约为1%)。