机器学习——k近邻算法(KNN)
程序员文章站
2022-07-14 20:44:22
...
import math
import numpy as np
from collections import Counter
class KNNClassfiy(object):
def __init__(self,k):
#判断k有效
assert k>=1,'k must be valid'
self.k=k
self._xTrain=None
self._yTrain=None
def fit(self,xTrain,yTrain):
#判断输入的训练集有效
assert xTrain.shape[0]==yTrain.shape[0],\
'The size of xTrain must be equals to the size of yTrain'
#判断K有效
assert self.k<=xTrain.shape[0],\
'The size of xTrain must be least at k'
self._xTrain=xTrain
self._yTrain=yTrain
return self
def predict(self,X_predict):
# X_predict是预测数据数组,判断预测数据合法性,必须是二维数组
assert X_predict.shape[1]==self._xTrain.shape[1],\
'The feature of x must be equal to xTrain'
assert self._xTrain is not None and self._yTrain is not None,\
'must fit before predict'
y_predict=[self._predict(x) for x in X_predict]
return np.array(y_predict)
def _predict(self,x):
distances=[math.sqrt(np.sum((xTrain-x)**2)) for xTrain in self._xTrain]
nearest=np.argsort(distances)
top_y=[self._yTrain[i] for i in nearest[:self.k]]
votes=Counter(top_y)
print(votes.most_common(1))
return votes.most_common(1)[0][0]
def __repr__(self):
return self.k
KNN_clf=KNNClassfiy(k=6);
#先训练后预测
xTrain=np.array([[4.5,3.2],
[5.8,4.1],
[6.7,5.3],
[8.6,7.1],
[3.8,2.5],
[5.3,4.4],
[9.4,8.6],
[11.8,9.4],
[3.8,3.2],
[12.8,10.1]])
yTrain=np.array([0,0,1,1,0,0,1,1,0,1])
KNN_clf.fit(xTrain=xTrain,yTrain=yTrain)
x_predict=np.array([[6.9,5.7],[3.4,2.8]])
a=KNN_clf.predict(x_predict)
print(a[0],a[1])
代码比较简单,主要逻辑在于预测部分。
调用matplotlib绘制图形分布图:
步骤可简化如下:
- 确定k值
- 训练数据集
- 预测函数
K近邻算法主要解决分类问题,是机器学习中最简单的最基础的一种算法。
上一篇: 软件开发的葵花宝典[转载]
下一篇: 【机器学习】K-近邻算法(KNN)