欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

K近邻算法(k-nearest neighbor, kNN)

程序员文章站 2024-01-22 11:15:52
...

K近邻算法(K-nearest neighbor, KNN)

KNN是一种分类和回归方法。

  • KNN简介
  • KNN模型3要素
  • KNN优缺点
  • KNN应用
  • 参考文献

KNN简介

KNN思想

给定一个训练集T={(x1,y1),(x2,y2),...,(xN,yN)},对新输入的实例x ,在训练集中找到与实例 x 最近的k个实例,根据k个实例中大多数类所属的类作为实例x 所属的类。

KNN算法

K近邻算法(k-nearest neighbor, kNN)

KNN模型3要素

K值得选择、距离度量方法选择、分类决策规则选择 

K值得选择
应用中,一般选择较小的k值,交叉验证可以选择最优的k值。
k值减小,模型变复杂,容易过拟合(原因:选择较小k值时,近似误差减小,估计误差增大)。
近似误差:即对现有训练集的训练误差,更关注“训练”。
估计误差:即对测试集的测试误差,更关注“测试”。
距离度量方法选择
欧氏距离
曼哈顿距离
切比雪夫距离 等等
分类决策规则选择
最常用的是,大多数原则:即由输入实例的k个近邻样本中大多数的类别确定输入实例的类。

KNN优缺点

优点
简单、精度高
缺点
计算时间、空间复杂度高

KNN应用

使用knn算法识别手写数字数据集,链接:https://pan.baidu.com/s/1rgiGBLTMiybCCSUnzR1lYw 密码:yse7

# -*-coding:utf-8-*-

from numpy import *
import operator
from os import listdir


def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]  # shape[0]读取矩阵第一维的长度
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet  # numpy.tile(A,B)函数重复A, B次
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    #print(type(distances))
    sortedDistIndicies = distances.argsort()

    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
    return sortedClassCount[0][0]


def img2vector(filename):
    returnVect = zeros((1, 1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0, 32*i + j] = int(lineStr[j])
    return returnVect


def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('digits/trainingDigits')           # 加载训练集
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]     # 提取文件名
        classNumStr = int(fileStr.split('_')[0])  # 提取类别标签
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('digits/trainingDigits/%s' % fileNameStr)
    testFileList = listdir('digits/testDigits')        # 加载测试集
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
        if (classifierResult != classNumStr): errorCount += 1.0
    print ("\nthe total number of errors is: %d" % errorCount)
    print ("\nthe total error rate is: %f" % (errorCount/float(mTest)))


if __name__ == '__main__':
    handwritingClassTest()

程序运行结果:
K近邻算法(k-nearest neighbor, kNN)

参考文献

[1]李航. 统计学习方法[M]. 清华大学出版社, 2012.
[2]Peter Harrington. 机器学习实战[M]. 人民邮电出版社, 2013.