欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

数据结构--K近邻算法实现(python)

程序员文章站 2022-05-09 21:37:57
引言 学了两天python基础之后,以及简单了解了一下简单的机器学习算法,以及py的第三方库numpy之后,python做数值计算相比java果然无比的方便通过所学知识简单地实现了一下最基础的K近邻分类算法,在此写篇博客记录一下此次学习。算法原理 初中时我们就学过两点间的距离公式,横纵坐标越相近的点,其类别也必然相似,举个极端的例子,相同事物的坐标必定相同,当然反之可能并不成立。因此如果我们可以将一个事物的特征抽象成一个二维或者多维的坐标,运用两点间的距离公式,就可以判断某个未知点的类别。...

引言

 学了两天python基础之后,以及简单了解了一下简单的机器学习算法,以及py的第三方库numpy之后,python做数值计算相比java果然无比的方便
通过所学知识简单地实现了一下最基础的K近邻分类算法,在此写篇博客记录一下此次学习。 

算法原理

 初中时我们就学过两点间的距离公式,横纵坐标越相近的点,其类别也必然相似,举个极端的例子,相同事物的坐标必定相同,当然反之可能并不成立。
因此如果我们可以将一个事物的特征抽象成一个二维或者多维的坐标,运用两点间的距离公式,就可以判断某个未知点的类别。 

AB=(x1x2)2+(y1y2)2 |AB| = \sqrt{(x_1-x_2)^2+(y_1-y_2)^2}

两点间的距离公式
d(p,q)=d(q,p)=(q1p1)2+(q2p2)2+(q3p3)2+...+(qnpn)2=r=1n(qipi)2 d(p,q) = d (q,p) = \sqrt{(q_1-p_1)^2+(q_2-p_2)^2+(q_3-p_3)^2+...+(q_n-p_n)^2} = \sqrt{\sum_{r=1}^n(q_i-p_i)^2}
类推上面的二维向量的距离公式,得多维向量的距离公式。

代码实现

import numpy as np import operator import matplotlib.pyplot as plt def createData(): print("初始化标准数据") group = np.array([[2, 100], [8, 88], [100, 5], [105, 7]]) lables = ['I片', 'I片', 'II片', 'II片'] return group, lables # inX:待分类点坐标 # dataSet:初始点坐标形成的list # labels:类别list # k: 用于标识有多少类,方便循环 def classify(input, dataSet, labels, k): # numpy中的shape方法用于计算形状 eg: dataSet: 4*2 # print(dataSet.shape) dataSetSize = dataSet.shape[0] # numpy中的tile方法,用于对矩阵进行填充 # 将inX矩阵填充至与dataSet矩阵相同规模,后相减 diffMat = np.tile(input, (dataSetSize, 1)) - dataSet # 平方 sqDiffMat = diffMat ** 2 # 求和 sqDistance = sqDiffMat.sum(axis=1) # 开方 distance = sqDistance ** 0.5 # argsort()方法进行直接排序 sortDist = distance.argsort() classCount = {} for i in range(k): # 取出前k个元素的类别 voteIlabel = labels[sortDist[i]] # dict.get(key,default=None),字典的get()方法,返回指定键的值,如果值不在字典中返回默认值。 # 计算类别次数 classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1 # 排序 sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) # 返回次数最多的类别,即所要分类的类别 return sortedClassCount[0][0] def draw(L): print(L) # 设置坐标图标题 plt.title("SimpleKNN") # 设置横轴 plt.xlabel("x") # 设置纵轴 plt.ylabel("y") # 描点 I类为红点,II类为绿点,待分类为蓝点,方便直观查看,检验分类效果 plt.plot([1, 5], [101, 89], 'ro') plt.plot([108, 115], [5, 8], 'go') plt.plot([L[0]], [L[1]], 'bo') plt.show() #初始化数据集 group, labels = createData() # 测试集 x = int(input("请输入待分类点横坐标x=")) y = int(input("请输入待分类点纵坐标y=")) test = [] test.append(x) test.append(y) print(type(test)) # kNN分类 test_class = classify(test, group, labels, 3) # 打印分类结果 print(test_class) # 绘图 draw(test) 

运行截图

数据结构--K近邻算法实现(python)

数据结构--K近邻算法实现(python)
数据结构--K近邻算法实现(python)
数据结构--K近邻算法实现(python)

本文地址:https://blog.csdn.net/qq_42338771/article/details/108243956