KNN最邻近分类算法 python实现
程序员文章站
2024-01-25 16:08:58
...
最邻近分类算法,故名思意就是在距离空间里,如果一个样本的最接近的k个邻居里绝大多数属于某个类别,则该样本也属于这个类别。
下面我们分别举两个例子实现,分别是电影分类(两个特征,便于之间通过散点图观察)和植物分类。
先引入常用的几个模块
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
电影分类
创建数据
rom sklearn import neighbors #导入KNN分类模块
from matplotlib import font_manager
my_font = font_manager.FontProperties(fname="/Library/Fonts/Songti.ttc")
import warnings
warnings.filterwarnings('ignore')
#不发出警告
data = pd.DataFrame({'name':['北京遇上西雅图','喜欢你','疯狂动物城','战狼','力王','敢死队'],
'fight':[3,2,1,101,99,98],
'kiss':[104,100,81,10,5,2],
'type':['Romance','Romance','Romance','Action','Action','Action']})
plt.scatter(data[data['type'] == 'Romance']['fight'],data[data['type'] == 'Romance']['kiss'],color = 'r',label='Romance')
plt.scatter(data[data['type'] == 'Action']['fight'],data[data['type'] == 'Action']['kiss'],color = 'g',label='Action')
plt.grid()
plt.legend()
通过原始时间我们可以清楚的看出他们的特征。
这时我们对[18,90]打架次数为18,接吻次数为90的《你的名字》用KNN算法进行预测。
knn = neighbors.KNeighborsClassifier()#KNN模型
knn.fit(data[['fight','kiss']],data['type'])
knn.predict([[18,90]])
#array(['Romance'], dtype=object)
#预测结果为爱情片
plt.scatter(18,90,color = 'r',marker='x',label='Romance')
plt.ylabel('kiss')
plt.xlabel('fight')
plt.text(20,90,'your name',color='r')
我们通过散点图也可以清楚的看到预测的结果。
我们建立更多数据通过训练KNN模型对其预测,并通过散点图直观看他的结果。
data2 = pd.DataFrame(np.random.randn(100,2)*50, columns=['fight','kiss'])
data2['typetest'] = knn.predict(data2)
plt.scatter(data[data['type'] == 'Romance']['fight'],data[data['type'] == 'Romance']['kiss'],color = 'r',label='Romance')
plt.scatter(data[data['type'] == 'Action']['fight'],data[data['type'] == 'Action']['kiss'],color = 'g',label='Action')
plt.grid()
plt.legend()
plt.scatter(data2[data2['typetest'] == 'Romance']['fight'],data2[data2['typetest'] == 'Romance']['kiss'],color='r',marker='x',label='Romance')
plt.scatter(data2[data2['typetest'] == 'Action']['fight'],data2[data2['typetest'] == 'Action']['kiss'],color='g',marker='x',label='Action')
圆点为用于训练模型的6个数据,x点为新建数据。
植物分类
这里我调用sklearn中的datasets调用官方数据,并把它转成pandas的Dataframe。
from sklearn import datasets
iris = datasets.load_iris()
print(iris.keys())
print('数据长度为:%i条'%(len(iris['data'])))
#导入官方数据
print(iris.feature_names)
print(iris.target_names)
#print(iris.target)
print(iris.data[:5])
data = pd.DataFrame(iris.data,columns=iris.feature_names)
data['target'] = iris.target
#150个实例数据
#feature_names 特征分类:萼片长度 萼片宽度 花瓣长度 花瓣宽度
#目标类别 ['setosa' 'versicolor' 'virginica']
"""
dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])
数据长度为:150条
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
['setosa' 'versicolor' 'virginica']
[[5.1 3.5 1.4 0.2]
[4.9 3. 1.4 0.2]
[4.7 3.2 1.3 0.2]
[4.6 3.1 1.5 0.2]
[5. 3.6 1.4 0.2]]
"""
#通过merge添加类别名称
ty = pd.DataFrame({'target':[0,1,2],'target_names':['setosa','versicolor','virginica']})
data = pd.merge(data,ty,on='target')
data.head()
影响类别的四个特征[‘sepal length (cm)’, ‘sepal width (cm)’, ‘petal length (cm)’, ‘petal width (cm)’],我们对原数据训练,并预测四个特征依此为[0.2,0.1,0.3,0.4]的植物类别。
knn = neighbors.KNeighborsClassifier() #KNN分类模型
#knn.fit(iris.data, iris.target)
knn.fit(iris.data, data['target_names']) #参数(特征,类别)
pre_data = knn.predict([[0.2,0.1,0.3,0.4]])
pre_data
#array(['setosa'], dtype=object)
预测结果为setosa(刺芒野古草)
上一篇: Ajax简介
下一篇: 如何删除数组中所有指定的item?