欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

机器学习第二个算法KNN(最邻近规则分类KNN算法)

程序员文章站 2024-01-25 16:09:10
...

最近邻规则分类(K-NearestNeighbors)KNN 算法,是由Cover和Hart在1968年提出的算法,叫做instance-based -learning,是一种分类算法,也叫(lazy learning)。因为它没有建立具体的模型,而是对未知的对象直接归类,看他属于哪个类即可。


先看一个简单的例子:
对于电影的分类,我们先粗略的把电影分成两类,而且主要是根据电影中出现的打斗、接吻情节的次数,来将其分成暴力和浪漫两种类型的电影。


机器学习第二个算法KNN(最邻近规则分类KNN算法)

对于这类问题我们需要学会将一些具体元素抽象化,比如电影名字可以抽象成一个个点,而将打斗和接吻视为坐标位置或者某一类问题的属性,维度等,可能以后还会遇到更多的属性或维度,而这个例子可以简化为:
机器学习第二个算法KNN(最邻近规则分类KNN算法)


这里要求未知点的类型,可以直接求该点距离其他已知点的距离(距离有很多种算法,最常见的有欧拉距离,曼哈顿距离等),将距离最小的一些已知点(多少个?可以设置为参数K)选择出来,通常来说距离越近,类型也会越近,我们从K个已知的实例点中,对不同类型计数,通过少数服从多数原则来判定未知点的类型。


显而易见,该算法非常容易理解,但是实用性可能会受点的分布影像,另外对于数据的计算量要求偏高,而且对参数K比较敏感。


彭亮老师给出了两个示例代码,在详细讲解代码之前,我们先来了解一下数据集。对于机器学习算法而言已经存在了很多的数据集,甚至针对不同的算法都有不同的数据集。对于knn算法,数据集前人也早已准备好,甚至已经集成到python的工具包里了。
机器学习第二个算法KNN(最邻近规则分类KNN算法)
总共150个实例,有四个属性,或者说维度,又分为三个类别。数据集可以直接从库里调用,也可以从网上直接下载,保存为iris.txt格式。(网上直接搜iris即可。)可见具有不同属性的花瓣可以分为三个类别,抽象成一个四个维数组带上自己的标签即可。

机器学习第二个算法KNN(最邻近规则分类KNN算法)

还是利用python的机器学习库sklearn;对两块代码稍作分析,如下所示:

"""
直接调用打包好的函数
"""

from sklearn import neighbors
from sklearn import datasets


knn =neighbors.KNeighborsClassifier() #knn分类
iris = datasets.load_iris()           #调用数据集

#理解成训练
knn.fit(iris.data,iris.target)  #将数据填进knn

#直接预测
predictedLabel = knn.predict([[0.1,0.2,0.3,0.4]])   #预测数据
print(predictedLabel)

彭亮老师还自己写函数来实现knn算法,代码值得琢磨一下:

import csv       #读取文件
import random    #产生随机数
import math
import operator


# 读取数据并且将数据分成了训练集和测试集

def loadDataset(filename,split,trainingSet=[],testSet=[]):
    with open(filename,'rt') as csvfile:
        lines = csv.reader(csvfile)  #读取数的每一行

        dataSet = list(lines)  #将数据变成了列表
        #print(dataset)
        #print(len(dataset))    #长度是150,150组数据

        #for m in range(4):
            #print(m)                  #0,1,2,3 四个


        for x in range(len(dataSet)):  #(0,149) len(dataSet)150个
            for y in range(4):
                dataSet[x][y] =float(dataSet[x][y])   #数据类型需要是float才能进行数据处理 

            if random.random() < split:
                trainingSet.append(dataSet[x])
            else:
                testSet.append(dataSet[x])

        #print(len(trainingSet))
        #print(len(testSet))

def euclideanDistance(instance1,instance2,length):
    distance = 0;
    for x in range(length):
        distance +=pow(instance1[x] - instance2[x],2)
    return math.sqrt(distance)


def getNeighbors(trainingSet,testInstance,k):
    distances = []
    length = len(testInstance)-1       #testInstance是一个实例数据

    for x in range(len(trainingSet)):  #trainingSet的每组数据(训练的所有数据和测试的数据做距离计算)
    #for x in range(5):
        dist = euclideanDistance(testInstance,trainingSet[x],length)  
        distances.append((trainingSet[x],dist))  #把dist加进去了(([a, b, c, d, 'Iris-setosa'], dist))
    distances.sort(key = operator.itemgetter(1)) #将测试集所有数据和Instance的距离求出来排序
    #print(distances) 

    neighbors = []
    for x in range(3):
        #print(distances[x][0])
        #print(distances[x][1])
        neighbors.append(distances[x][0])  #只取k个,并且将列在后面的dist抛弃(取第一个数据)

    #print(distances[0][0])
    #print(distances[1][0])
    #print(neighbors)

    return neighbors


def getResponse(neighbors):
    classVotes = {}   #字典

    for x in range(len(neighbors)):
        response = neighbors[x][-1]    #最后一个,即目标

        #print(response)

        if response in classVotes:   #如果没有出现,则为1,出现过了,每出现一次加1
            classVotes[response] += 1
        else:
            classVotes[response] = 1

        sortedVotes = sorted(classVotes.items(),key = operator.itemgetter(1),reverse = True)
        #按字典里的第二个参数进行排序,即是出现的次数

    #print(sortedVotes[0][0])
    #print(classVotes.items())
    #print(sortedVotes)

    return sortedVotes[0][0]  #给出排列在第一的目标

def getAccuracy(testSet,predictions):
    correct = 0
    for x in range(len(testSet)):
        if testSet[x][-1] == predictions[x]:
            correct +=1
    return (correct/float(len(testSet)) *100.0)



def main():
    trainingSet = []
    testSet = []
    split = 0.66667
    loadDataset(r'D:\Machine_Learning\NearestNeighbors\iris.txt',split,trainingSet,testSet)

    #print('Train set:' +str(len(trainingSet)))
    #print('Test set:' +repr(len(testSet)))


    predictions = []
    k = 3
    t =0
    for x in range(len(testSet)):  #对测试集里的每一个元素进行判断
    #for x in range(1):

        neighbors = getNeighbors(trainingSet,testSet[x],k)

        result = getResponse(neighbors)

        predictions.append(result)

        print('>predicted=' +repr(result) +',actual = '+repr(testSet[x][-1]))

        if testSet[x][-1] != predictions[x]:           
            t +=1
            print('wrong '+'>predicted=' +repr(result) +',actual = '+repr(testSet[x][-1]))


    accuracy =getAccuracy(testSet,predictions)
    print('Accuracy:' +repr(accuracy) +'%')
    print('wrongNum ' +str(t))

main()

有不熟悉的地方可以把我注释掉的地方打印出来瞅瞅,就明白很多了。