欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

机器学习算法原理总结系列---算法基础之(4)最邻近规则分类(K-Nearest Neighbor)

程序员文章站 2024-01-25 16:04:47
...

最邻近规则分类(K-NearestNeighbor)简称KNN,虽然叫NN,其实并不是什么神经网络算法,而是基于计算距离的监督学习分类算法模型而已。

一、原理详解

  1. 综述
    1.1 Cover和Hart在1968年提出了最初的邻近算法
    1.2 分类(classification)算法
    1.3 输入基于实例的学习(instance-based learning), 懒惰学习(lazy learning)

  2. kNN算法的核心思想:
    如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。由于kNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,kNN方法较其他方法更为适合。

定义如果太抽象,那直接看一个经典具体的例子,4张图让你豁然开朗。
机器学习算法原理总结系列---算法基础之(4)最邻近规则分类(K-Nearest Neighbor)
任务:图中显示了6部电影的打斗和接吻镜头数。假如有一部未看过的电影,如何确定它是爱情片还是动作片呢?
咱们先来做个替换,换成相对科学一点的表示方式:
机器学习算法原理总结系列---算法基础之(4)最邻近规则分类(K-Nearest Neighbor)
这样有A-F个样本,每个样本都有两个特征值,将其映射在直角坐标系上,标签分类有两种,一个是爱情片,一个是动作片。
好的,这样我就可以进一步画出坐标图。
机器学习算法原理总结系列---算法基础之(4)最邻近规则分类(K-Nearest Neighbor)

import matplotlib.pyplot as plt

plt.title("Use the number of fighting and kissing shots to categorize the movie")
plt.xlim(xmax=110, xmin=-20)
plt.ylim(ymax=110, ymin=-20)
plt.xlabel("The number of shots that appear in the movie")
plt.ylabel("The number of kissing shots appearing in the movie")
plt.annotate("Unknown movie", xy=(18, 90), xytext=(35, 89), arrowprops=dict(facecolor='black', shrink=0.1))
plt.annotate("A", xy=(3, 104))
plt.annotate("B", xy=(2, 100))
plt.annotate("C", xy=(1, 81))
plt.annotate("D", xy=(101, 10))
plt.annotate("E", xy=(99, 5))
plt.annotate("F", xy=(98, 2))

x = [3, 2, 1, 101, 99, 98, 18]
y = [104, 100, 81, 10, 5, 2, 90]
plt.plot(x, y, 'ro')
plt.show()

因此更加清楚的可以得出:
机器学习算法原理总结系列---算法基础之(4)最邻近规则分类(K-Nearest Neighbor)

若此时假定k=3,那么A、B、C三种电影的距离离未知电影的距离最近,并且这三部都是爱情片,所以我们预测,未知影片也是一部爱情片。
若k=4,那么就采用投票法则,少数服从多数,依然可以判定是爱情片。

故经过这里例子表述,总结算法详述为:

  1. 步骤:

    • 为了判断未知实例的类别,以所有已知类别的实例作为参照
    • 选择参数K
    • 计算未知实例与所有已知实例的距离
    • 选择最近K个已知实例
    • 根据少数服从多数的投票法则(majority-voting),让未知实例归类为K个最邻近样本中最多数的类别
  2. 细节:

    • 关于K
    • 关于距离的衡量方法:
      Euclidean Distance 定义
      机器学习算法原理总结系列---算法基础之(4)最邻近规则分类(K-Nearest Neighbor)
      机器学习算法原理总结系列---算法基础之(4)最邻近规则分类(K-Nearest Neighbor)
      其他距离衡量:余弦值(cos), 相关度 (correlation), 曼哈顿距离 (Manhattan distance)
  3. 算法优缺点:

    1. 算法优点

      • 简单
      • 易于理解
      • 容易实现
      • 通过对K的选择可具备丢噪音数据的健壮性
    2. 算法缺点
      机器学习算法原理总结系列---算法基础之(4)最邻近规则分类(K-Nearest Neighbor)

      • 需要大量空间储存所有已知实例
      • 算法复杂度高(需要比较所有已知实例与要分类的实例)
      • 当其样本分布不平衡时,比如其中一类样本过大(实例数量过多)占主导的时候,新的未知实例容易被归类为这个主导样本,因为这类样本实例的数量过大,但这个新的未知实例实际并木接近目标样本

二、原生代码实现

数据集依然iris数据集,下载地址为:

IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

全部代码:

import csv
import math

import operator


def load_dataset(train_data_file, test_data_file):
    """
    1、数据来源是iris数据集,一共150例,其中分为3类:iris-setosa, iris-versicolor,iris-virginica
    2、在监督学习中使用标记 0 --->  iris-setosa
                         1 --->  iris-versicolor
                         2 --->  iris-virginica
    3、训练集一共选取120例,其中均匀分布着三类标记。
    4、测试集选取30例
    """
    iris_train_data = []
    with open(train_data_file, 'r') as reader:
        lines = csv.reader(reader)
        for k, row in enumerate(lines):
            if k == 0:
                continue
            else:
                for item in range(len(row) - 1):
                    row[item] = float(row[item])
            iris_train_data.append(row)

    iris_test_data = []
    with open(test_data_file, 'r') as reader:
        lines = csv.reader(reader)
        for k, row in enumerate(lines):
            if k == 0:
                continue
            else:
                for item in range(len(row) - 1):
                    row[item] = float(row[item])
            iris_test_data.append(row)

    return iris_train_data, iris_test_data


def euclidean_metric(instance_1, instance_2, length):
    distance = 0
    for x in range(length):
        distance += pow(instance_1[x] - instance_2[x], 2)
    return math.sqrt(distance)


def get_neighbors(train_set, test_instance, k):
    distances = []
    length = len(test_instance) - 1
    for x in range(len(train_set)):
        dist = euclidean_metric(test_instance, train_set[x], length)
        distances.append((train_set[x], dist))
    distances.sort(key=operator.itemgetter(1))
    neighbors = []
    for x in range(k):
        neighbors.append(distances[x][0])
    return neighbors


def get_response(neighbors):
    class_votes = {}
    for x in range(len(neighbors)):
        response = neighbors[x][-1]
        if response in class_votes:
            class_votes[response] += 1
        else:
            class_votes[response] = 1
    sorted_votes = sorted(class_votes.items(), key=operator.itemgetter(1), reverse=True)
    return sorted_votes[0][0]


def get_accuracy(test_set, predictions):
    correct = 0
    for x in range(len(test_set)):
        if test_set[x][-1] == predictions[x]:
            correct += 1
    return (correct / float(len(test_set))) * 100.0


if __name__ == '__main__':
    # prepare data
    iris_train_data, iris_test_data = load_dataset(r'iris_training.csv', r'iris_test.csv')
    # print(iris_train_data)
    # print(iris_test_data)

    # generate predictions
    predictions = []
    k = 3
    for x in range(len(iris_test_data)):
        neighbors = get_neighbors(iris_train_data, iris_test_data[x], k)
        result = get_response(neighbors)
        predictions.append(result)
        print('> predicted=' + repr(result) + ', actual=' + repr(iris_test_data[x][-1]))
    accuracy = get_accuracy(iris_test_data, predictions)
    print('Accuracy: ' + repr(accuracy) + '%')

机器学习算法原理总结系列---算法基础之(4)最邻近规则分类(K-Nearest Neighbor)
识别率到达了96.7%,分类效果已经非常优秀了。简单的神经网络NN算法都没有达到这样的识别率。所以说对于特定的数据集。并不一定什么都用深度神经网络。简单的机器学习基础算法都可以有很好的分类效果。

三、scikit-learn代码实现

from sklearn import neighbors
from sklearn import datasets

knn = neighbors.KNeighborsClassifier()

iris = datasets.load_iris()

print(iris)

knn.fit(iris.data, iris.target)

predictedLabel = knn.predict([[0.1, 0.2, 0.3, 0.4]])

print(predictedLabel)

代码的输出为【0】。大家自行研究。

scikit-learn库对整个KNN算法进行封装,基本就是黑盒操作,我们只需要创建KNN实例,然后把数据fit,之后调用实例的predict方法,就可以对测试数据进行分类预测了。

相关标签: 机器学习 算法