欢迎您访问程序员文章站!本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

python实现KNN分类算法

程序员文章站 2022-04-08 08:12:27
...

K近邻(K-Nearest Neighbors)算法,简称KNN算法,主要用于分类
三大要素

  • 距离度量(常用欧氏距离或曼哈顿距离衡量远近)
  • K值选择(最近的K个邻居)
  • 决策规则(少数服从多数)
    python实现KNN分类算法
    已知训练集
    python实现KNN分类算法
import numpy
import matplotlib.pyplot as plt

'''
已知训练集和训练集类别、测试集
1.测试集广播和训练集一样的规格
2.计算距离
3.对距离列表排序返回最近的K个点的下标
4.有几个类别就设几个标签用来统计,循环排序列表,对类别判断,少数服从多数
5.数据可视化
'''

def knn_classify(train_points, train_labels, query, k=5, metric="euclidean"):
    """Classify ``query`` by majority vote among its ``k`` nearest training points.

    Args:
        train_points: 2-D array-like, one training point per row.
        train_labels: sequence of class labels, one per row of ``train_points``.
        query: 1-D array-like with one coordinate per column of ``train_points``.
        k: number of nearest neighbours that vote (``1 <= k <= n_rows``).
            An odd ``k`` avoids ties when there are only two classes.
        metric: ``"euclidean"`` (default) or ``"manhattan"``.

    Returns:
        ``(label, votes)``: the winning label and a ``collections.Counter``
        mapping each label to its number of votes among the ``k`` neighbours.
        A tied vote goes to the class of the nearest tied neighbour.

    Raises:
        ValueError: if ``train_points`` is not 2-D, the number of labels does
            not match the number of points, ``k`` is out of range, or
            ``metric`` is unknown.
    """
    from collections import Counter

    points = numpy.asarray(train_points, dtype=float)
    target = numpy.asarray(query, dtype=float)
    if points.ndim != 2:
        raise ValueError("train_points must be 2-D: one training point per row")
    n_samples = points.shape[0]
    if len(train_labels) != n_samples:
        raise ValueError(f"got {len(train_labels)} labels for {n_samples} training points")
    if not 1 <= k <= n_samples:
        raise ValueError(f"k must be between 1 and {n_samples}, got {k}")

    # Broadcasting subtracts the query from every row, so no numpy.tile is needed.
    diff = points - target
    if metric == "euclidean":
        distances = numpy.sqrt(numpy.sum(diff ** 2, axis=1))
    elif metric == "manhattan":
        distances = numpy.sum(numpy.abs(diff), axis=1)
    else:
        raise ValueError(f"unknown metric {metric!r}; expected 'euclidean' or 'manhattan'")

    # A stable sort breaks distance ties by training-set order, keeping results reproducible.
    nearest = numpy.argsort(distances, kind="stable")[:k]
    # One counter per class (works for any number of classes). Neighbours are
    # visited nearest-first, and most_common() keeps first-seen order among
    # equal counts, so a tied vote is won by the closest tied neighbour's class.
    votes = Counter(train_labels[i] for i in nearest)
    label = votes.most_common(1)[0][0]
    return label, votes


def _plot_knn_result(train_points, train_labels, query, query_label):
    """Scatter training points (red for "A", green otherwise) and the query as a star."""
    points = numpy.asarray(train_points)
    is_a = numpy.array([label == "A" for label in train_labels], dtype=bool)
    plt.figure()
    plt.title("WGS")
    plt.scatter(points[is_a, 0], points[is_a, 1], c="r")
    plt.scatter(points[~is_a, 0], points[~is_a, 1], c="g")
    plt.scatter(query[0], query[1], c="r" if query_label == "A" else "g",
                marker="*", label="Test point")
    plt.grid(True)
    # Legend anchored above the top-left corner of the axes.
    plt.legend(bbox_to_anchor=(0, 1.1), loc=2, borderaxespad=0)     # https://blog.csdn.net/Poul_henry/article/details/82533569
    # Previously commented out, which meant the figure was built but never displayed.
    plt.show()


if __name__ == '__main__':
    k = 5  # given; odd, so a two-class vote cannot tie
    train_labels = ["B", "B", "A", "A", "B", "B", "A", "B", "A", "A", "B", "A"]  # class of each training point
    # Training coordinates; the plot reads columns 0 and 1. Presumably one
    # comma-separated x,y pair per line, with one row per label above.
    train_points = numpy.loadtxt("knndata1.txt", delimiter=",")
    query = numpy.array([0.7, 0.7])  # test point

    # Euclidean is the usual choice; pass metric="manhattan" for city-block distance.
    label, votes = knn_classify(train_points, train_labels, query, k=k)

    a = votes["A"]
    b = k - a  # every non-"A" neighbour counts towards "B", as in the original tally
    print(a, b)
    print(f"I am {label}")

    # Data visualisation
    _plot_knn_result(train_points, train_labels, query, label)
小案例

python实现KNN分类算法

import numpy
import matplotlib.pyplot as plt

def knn_predict(train_points, train_labels, query, k=3):
    """Return the majority label among the ``k`` training points nearest to ``query``.

    Distances are Euclidean.

    Args:
        train_points: 2-D array-like, one training sample per row.
        train_labels: sequence of labels, one per row of ``train_points``.
        query: 1-D array-like with one value per column of ``train_points``.
        k: number of nearest neighbours that vote (``1 <= k <= n_rows``).

    Returns:
        The label that receives the most votes. A tied vote goes to the class
        of the nearest tied neighbour.

    Raises:
        ValueError: if ``train_points`` is not 2-D, the number of labels does
            not match the number of rows, or ``k`` is out of range.
    """
    from collections import Counter

    points = numpy.asarray(train_points, dtype=float)
    if points.ndim != 2:
        raise ValueError("train_points must be 2-D: one training sample per row")
    n_samples = points.shape[0]
    if len(train_labels) != n_samples:
        raise ValueError(f"got {len(train_labels)} labels for {n_samples} training samples")
    if not 1 <= k <= n_samples:
        raise ValueError(f"k must be between 1 and {n_samples}, got {k}")

    # Broadcasting compares the query against every row at once (no numpy.tile needed).
    distances = numpy.sqrt(numpy.sum((points - numpy.asarray(query, dtype=float)) ** 2, axis=1))
    # A stable sort makes equal distances resolve in training-set order.
    nearest = numpy.argsort(distances, kind="stable")[:k]
    # most_common() keeps first-seen order among equal counts, and neighbours
    # are visited nearest-first, so the closest tied neighbour's class wins a tie.
    votes = Counter(train_labels[i] for i in nearest)
    return votes.most_common(1)[0][0]


def _plot_movies(train_points, train_labels, query, query_label):
    """Scatter samples labelled "爱情片" in red, the rest in green, and the query as '+'."""
    points = numpy.asarray(train_points)
    is_romance = numpy.array([label == "爱情片" for label in train_labels], dtype=bool)
    plt.scatter(points[is_romance, 0], points[is_romance, 1], c="r")
    plt.scatter(points[~is_romance, 0], points[~is_romance, 1], c="g")
    plt.scatter(query[0], query[1], c="r" if query_label == "爱情片" else "g", marker="+")
    plt.show()


if __name__ == '__main__':
    # Training set: title -> [x, y, genre label]
    traindata = {
        "California Max": [3, 104, "爱情片"],
        "He's Not Really into Dudes": [2, 100, "爱情片"],
        "Beautiful": [1, 81, "爱情片"],
        "Kevin Longblande": [101, 10, "动作片"],
        "Robo Slayer": [99, 5, "动作片"],
        "Amped": [98, 2, "动作片"],
    }
    # Test set: the sample to classify; its label is unknown
    testdata = {"Hi boy": [18, 90, "未知"]}

    # Split each record into its (x, y) features and its label.
    records = list(traindata.values())
    train_points = numpy.array([record[:2] for record in records], dtype=float)
    train_labels = [record[2] for record in records]
    query = numpy.array(testdata["Hi boy"][:2], dtype=float)

    k = 3  # odd, so a two-genre vote cannot tie
    genre = knn_predict(train_points, train_labels, query, k=k)
    print(genre)

    # Data visualisation
    _plot_movies(train_points, train_labels, query, genre)
相关标签: 算法