欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

机器学习小白修炼之路

程序员文章站 2024-03-14 20:20:29
...

机器学习基本算法之KNN近邻算法(有监督学习)

一、思路、图解及度量距离

思路
找出某个待预测样本在特征空间中最相似的K个样本,如果其中绝大多数属于某个类别,则待测样本也属于那个类别。

图解:
机器学习小白修炼之路如果K=3,绿色圆圈最相近的三个样本分别是两个红色三角形和一个蓝色正方形,根据决策,判定绿色圆圈属于红色的三角形一类。

如果K=5,绿色圆圈最相近的五个样本分别是两个红色三角形和三个蓝色正方形,根据决策,判定绿色圆圈属于蓝色的正方形一类。

两种度量的距离:
1、欧拉距离:简单的平面几何中两点之间的直线距离。

机器学习小白修炼之路
2、曼哈顿距离:棋盘上会使用的距离计算方法。
机器学习小白修炼之路

二、算法步骤

1、计算已知类别数据集中的点与当前点的距离
2、按照距离一次排序
3、选取与当前点距离最小的K个点
4、确定前K个点所在类别的出现概率
5、返回前K个点出现频率最高的类别作为当前点预测分类

三、存在缺陷与改进措施

缺陷
如果样本不平衡时,一个类的样本容量很大,其他类样本容量很小时,很大可能会导致当输入一个新样本时,该样本的 K 个邻居中大容量类的样本占多数,而忽略了小样本容量的影响。

手法
不同距离的样本给予不同的权重。

四、代码实现

数据可视化
机器学习小白修炼之路
python实现

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

#加载特征向量和标签
def load_dataset(filename):
    file=open(filename,"rb")
    lines=file.readlines()
    m=len(lines)
    index=0
    featuresSet=np.zeros((m,3))
    label=[]
    for line in lines:
        #取掉每一行的头尾空格
        line=line.decode("utf-8").strip()
        #切片的返回值是 一串向量
        featuresSet[index,:]=line.split('\t')[0:3]
        label.append(int(line.split("\t")[-1]))
        index+=1
    return featuresSet,label

#参数分别表示测试集合的某行、训练集、标签、k值
def knn_classify(row_feature,featuresSet,label,k):
    m=len(featuresSet[:,1])
    n=len(featuresSet[1,:])

    #归一化,python里面的数据归一化真的和Matlab里面完全不一样,python复杂点
    #矩阵.min(0):结果是每一列最小值组成的向量;  矩阵.min(1):结果是每一行最小值组成的向量
    #mean()和std()同min()和max()
    minVal=featuresSet.min(0)
    mean=featuresSet.mean(0)
    std=featuresSet.std(0)
    maxVal=featuresSet.max(0)
    #普通的归一化
    #featuresSet=(featuresSet-np.tile(minVal,(m,1)))/(np.tile(maxVal,(m,1))-np.tile(minVal,(m,1)))
    #row_feature=(row_feature-minVal)/(maxVal-minVal)
    #np.tile(x,(m,n)) 类似铺瓷砖,将x铺成m*n的样子
    #标准的归一化
    featuresSet=(featuresSet-np.tile(mean,(m,1)))/np.tile(std,(m,1))
    row_feature=(row_feature-mean)/std

    #求出待预测样本和所有训练集中的欧拉距离
    diffMat=np.tile(row_feature,(m,1))-featuresSet
    sqDiffMat=diffMat**2;
    sqDistance=sqDiffMat.sum(axis=1)
    distance=sqDistance**0.5

    #某几种类别的样本容量差距很大时,给distance附上权重
    #distance=distance/sqDistance
    
    #argsort()函数是将list从小到大排序,提取出index
    sortedDistIndices=distance.argsort()
    #字典的形式:{标签:标签出现的次数}
    classCount={}
    for i in range(k):
        voteLabel=label[sortedDistIndices[i]]
        #根据字典的key找到对应的value,指定default0
        classCount[voteLabel]=classCount.get(voteLabel,0)+1

    #items()函数表示将字典转换成list。python3删除了cmp的定义,这玩意和快排的cmp是一个道理
    #默认是递增排序,reverse=True反转为递减排序,找到出现次数最多的那个label
    sortedclassCount=sorted(classCount.items(),key=lambda s:s[1],reverse=True)
    return sortedclassCount[0][0]

def test(numTestVecs):
    x,y=load_dataset("datingTestSet.txt")
    m=x.shape[0]
    errorCount=0.0
    for i in range(numTestVecs):
        classifierResult=knn_classify(x[i,:],x[numTestVecs:m,:],y[numTestVecs:m],4)
        print("模型预测值:%d, 真实值:%d" %(classifierResult,y[i]))
        if(classifierResult!=y[i]):
            errorCount+=1.0
    errorRate=errorCount/float(numTestVecs)
    print("正确率:%f"%(1-errorRate))
    return 1-errorRate

def createScatter():
    featureSet,label=load_dataset("datingTestSet.txt")
    type1_x=[]
    type1_y=[]
    type2_x=[]
    type2_y=[]
    type3_x=[]
    type3_y=[]
    fig=plt.figure()
    axes=plt.subplot(1,1,1)
    #防止中文乱码
    plt.rcParams["font.sans-serif"]=["SimHei"]
    for i in range(len(label)):
        if label[i]==1:
            type1_x.append(featureSet[i][0])
            type1_y.append(featureSet[i][1])
        if label[i]==2:
            type2_x.append(featureSet[i][0])
            type2_y.append(featureSet[i][1])
        if label[i]==3:
            type3_x.append(featureSet[i][0])
            type3_y.append(featureSet[i][1])
    type1=axes.scatter(type1_x,type1_y,s=20,c='red')
    type2=axes.scatter(type2_x,type2_y,s=40,c='green')
    type3=axes.scatter(type3_x,type3_y,s=50,c='blue')
    plt.title(u'Pokemon')
    plt.xlabel(u'种族值*100')
    plt.ylabel(u"个体值")
    axes.legend((type1,type2,type3),(u'堪比鲤鱼王', u'潜力一般', u'极具发展潜力'), loc=2)
    plt.show()

if __name__ == '__main__':
    x,y=load_dataset("datingTestSet.txt")
    createScatter()
    #1-200条数据为测试集,201-1000条数据为样本集
    test(200)

结果
机器学习小白修炼之路

五、数据集下载


数据集下载链接

提取码:lxnm