Machine Learning: The k-Nearest Neighbors Algorithm (KNN)

程序员文章站 2022-07-14 20:44:22
...
import math
import numpy as np
from collections import Counter
class KNNClassfiy(object):
    def __init__(self,k):
        # Check that k is valid
        assert k>=1,'k must be at least 1'
        self.k=k
        self._xTrain=None
        self._yTrain=None


    def fit(self,xTrain,yTrain):
        # Check that the training set is valid: one label per sample
        assert xTrain.shape[0]==yTrain.shape[0],\
            'The size of xTrain must be equal to the size of yTrain'
        # Check that k does not exceed the number of training samples
        assert self.k<=xTrain.shape[0],\
            'The size of xTrain must be at least k'
        self._xTrain=xTrain
        self._yTrain=yTrain
        return self

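    def predict(self,X_predict):
        # The model must be fitted before it can predict; check this first,
        # because the shape check below reads self._xTrain
        assert self._xTrain is not None and self._yTrain is not None,\
            'must fit before predict'
        # X_predict holds the samples to classify; it must be a 2-D array
        # with the same number of features as the training data
        assert X_predict.shape[1]==self._xTrain.shape[1],\
            'The feature of x must be equal to xTrain'
        y_predict=[self._predict(x) for x in X_predict]
        return np.array(y_predict)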

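    def _predict(self,x):
        # Euclidean distance from x to every training sample
        distances=[math.sqrt(np.sum((xTrain-x)**2)) for xTrain in self._xTrain]
        # Indices of the training samples, nearest first
        nearest=np.argsort(distances)
        # Labels of the k nearest neighbours
        top_y=[self._yTrain[i] for i in nearest[:self.k]]
        # Majority vote among those labels
        votes=Counter(top_y)
        # Show the winning label and its vote count
        print(votes.most_common(1))
        return votes.most_common(1)[0][0]

    def __repr__(self):
        # __repr__ must return a string, not an int
        return 'KNNClassfiy(k=%d)' % self.k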

KNN_clf=KNNClassfiy(k=6)
# Fit first, then predict
xTrain=np.array([[4.5,3.2],
                 [5.8,4.1],
                 [6.7,5.3],
                 [8.6,7.1],
                 [3.8,2.5],
                 [5.3,4.4],
                 [9.4,8.6],
                 [11.8,9.4],
                 [3.8,3.2],
                 [12.8,10.1]])
yTrain=np.array([0,0,1,1,0,0,1,1,0,1])
KNN_clf.fit(xTrain=xTrain,yTrain=yTrain)
x_predict=np.array([[6.9,5.7],[3.4,2.8]])
a=KNN_clf.predict(x_predict)
print(a[0],a[1])

The code is fairly simple; the main logic lies in the prediction step.
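The prediction step can also be written without NumPy. The following is a minimal, dependency-free sketch rather than the article's class: the function name `knn_predict` and the tuple-based data layout are assumptions made for illustration, while the data points and k=6 are taken from the example above.

```python
import math
from collections import Counter

def knn_predict(train_points, train_labels, query, k):
    """Return the majority label among the k training points nearest to query."""
    # Euclidean distance from the query to every training point
    distances = [
        math.sqrt(sum((a - b) ** 2 for a, b in zip(point, query)))
        for point in train_points
    ]
    # Indices of the training points, sorted from nearest to farthest
    order = sorted(range(len(train_points)), key=lambda i: distances[i])
    # The labels of the k nearest neighbours vote; the most common label wins
    votes = Counter(train_labels[i] for i in order[:k])
    return votes.most_common(1)[0][0]

# Same training data as in the article
points = [(4.5, 3.2), (5.8, 4.1), (6.7, 5.3), (8.6, 7.1), (3.8, 2.5),
          (5.3, 4.4), (9.4, 8.6), (11.8, 9.4), (3.8, 3.2), (12.8, 10.1)]
labels = [0, 0, 1, 1, 0, 0, 1, 1, 0, 1]

print(knn_predict(points, labels, (6.9, 5.7), k=6))
print(knn_predict(points, labels, (3.4, 2.8), k=6))
```

The three commented steps (distances, sort, vote) correspond one-to-one to the body of `_predict` in the class.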

The data distribution is plotted with matplotlib:

[Figure: distribution plot of the data]
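The plotting code itself is not shown in the article. The sketch below is one possible way to draw such a distribution plot, under stated assumptions: the helper names `group_by_label` and `plot_distribution`, the output file name `knn_scatter.png`, the axis labels, and the choice to mark the two points to be predicted are all hypothetical and not taken from the original.

```python
def group_by_label(points, labels):
    """Collect the points of each class so every class gets its own colour."""
    groups = {}
    for point, label in zip(points, labels):
        groups.setdefault(label, []).append(point)
    return groups

def plot_distribution(points, labels, queries, path="knn_scatter.png"):
    """Scatter-plot the training points per class and mark the query points."""
    # Imported lazily so group_by_label works without matplotlib installed
    import matplotlib
    matplotlib.use("Agg")  # render to a file; no display is needed
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    for label, pts in sorted(group_by_label(points, labels).items()):
        xs, ys = zip(*pts)
        ax.scatter(xs, ys, label=f"class {label}")
    qx, qy = zip(*queries)
    ax.scatter(qx, qy, marker="x", color="black", label="points to predict")
    ax.set_xlabel("feature 1")  # assumed axis names
    ax.set_ylabel("feature 2")
    ax.legend()
    fig.savefig(path)
    plt.close(fig)
    return path

# Same training data and query points as in the article
points = [(4.5, 3.2), (5.8, 4.1), (6.7, 5.3), (8.6, 7.1), (3.8, 2.5),
          (5.3, 4.4), (9.4, 8.6), (11.8, 9.4), (3.8, 3.2), (12.8, 10.1)]
labels = [0, 0, 1, 1, 0, 0, 1, 1, 0, 1]
queries = [(6.9, 5.7), (3.4, 2.8)]
```

Calling `plot_distribution(points, labels, queries)` writes the figure to the given path and returns that path.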

The procedure can be summarized as follows:

  • Choose the value of k
  • Fit the training data set
  • Run the prediction function
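The article does not say how k should be chosen. As a small illustration of why the first step matters, the sketch below tallies the neighbour votes for the first query point, (6.9, 5.7), under several values of k. The helper name `neighbor_votes` and the set of k values tried are assumptions made for illustration; the data is the article's.

```python
import math
from collections import Counter

# Same training data as in the article
points = [(4.5, 3.2), (5.8, 4.1), (6.7, 5.3), (8.6, 7.1), (3.8, 2.5),
          (5.3, 4.4), (9.4, 8.6), (11.8, 9.4), (3.8, 3.2), (12.8, 10.1)]
labels = [0, 0, 1, 1, 0, 0, 1, 1, 0, 1]

def neighbor_votes(query, k):
    """Count the labels of the k training points nearest to query."""
    def distance(point):
        return math.sqrt(sum((a - b) ** 2 for a, b in zip(point, query)))
    # Training point indices, nearest first
    order = sorted(range(len(points)), key=lambda i: distance(points[i]))
    return Counter(labels[i] for i in order[:k])

# Compare the vote tally for a few candidate values of k
for k in (1, 3, 5, 6):
    print(k, dict(neighbor_votes((6.9, 5.7), k)))
```

For this query, class 1 wins 1 to 0 at k=1, class 0 wins 2 to 1 at k=3 and 3 to 2 at k=5, and at k=6 the vote is a 3 to 3 tie, so the result then depends on how `Counter.most_common` orders equal counts. With two classes, an odd k avoids such ties.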

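The k-nearest neighbors algorithm is mainly used for classification problems and is one of the simplest and most basic algorithms in machine learning.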
