欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

统计学习方法(1)——感知机实现(学习算法的原始形式)

程序员文章站 2022-07-11 12:22:22
...

最近在看李航老师的《统计学习方法》,打算实现每一个算法。置于算法的具体介绍和讲解,此处不做详细介绍,需要了解算法内容的同学,可以看一下书上的对应章节。

这次实现数据参照书中第二章例2.1
实现了感知机学习算法的原始形式

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# author= icesun

import numpy as np

# 感知机学习算法的原始形式
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# author= icesun

import numpy as np

# 感知机学习算法的原始形式
class Perceptron():
    def __init__(self, x, y, lr):
        # 数据样例个数
        self.len = len(x)

        self.x = x
        self.y = y
        self.lr = lr
        # 随机初始化参数
        # self.w = np.random.random(x[0].shape)
        # self.b = np.random.random(y[0].shape)
        # 将参数初始化为0
        self.w = np.zeros(shape=(2))
        self.b = np.zeros(shape=(1,1))

    def print_w_b(self):
        print("w:", self.w)
        print("b:", self.b)

    def train(self):
        while True:
            # t记录正确分类的样本数
            t = 0
            # 对于样本点改成随机抽取,但是此处我选择直接按序遍历
            for x, y in zip(self.x, self.y):
                print(x, '--', y)
                if np.matmul(y, (np.matmul(x, self.w) + self.b)) <= 0:
                    # 更新参数
                    self.w = self.w + self.lr * y * x
                    self.b = self.b + self.lr * y
                    break
                else:
                    t += 1
            if t == self.len:
                break

if __name__ == '__main__':
    x = np.array([3, 3, 4, 3, 1, 1,1,1]).reshape(4,2)
    y = np.array([1, 1, -1]).reshape(3,1)
    perceptron = Perceptron(x, y, 1)
    perceptron.train()
    perceptron.print_w_b()
训练输出结果:
[3 3] -- [1]
[3 3] -- [1]
[4 3] -- [1]
[1 1] -- [-1]
[3 3] -- [1]
[4 3] -- [1]
[1 1] -- [-1]
[3 3] -- [1]
[4 3] -- [1]
[1 1] -- [-1]
[3 3] -- [1]
[3 3] -- [1]
[4 3] -- [1]
[1 1] -- [-1]
[3 3] -- [1]
[4 3] -- [1]
[1 1] -- [-1]
[3 3] -- [1]
[4 3] -- [1]
[1 1] -- [-1]
w: [1. 1.]
b: [[-3.]]