欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

感知机算法的python实现+可视化

程序员文章站 2022-07-12 12:02:21
...

最入门的机器学习算法 简单的实现+可视化

上代码:


from numpy import *
import matplotlib.pyplot as plt
def getLinaerSeparatableData(weight,numLines):
    a=len(weight)
    dataSet=zeros((numLines,a+1))
    for i in range(numLines):
        b=random.rand(1,a)*20-10

        if sum(b*weight)<=0:
            dataSet[i]=append(b,-1)
        else:
            dataSet[i]=append(b,+1)
    return dataSet

dataset=getLinaerSeparatableData([3,4],100)

def train(dataSet):
    separated=False
    numLines=dataSet.shape[0]
    numFeatures=(dataSet.shape[1]-1)
    w=zeros((1,numFeatures))
    i=0
    while separated==False and i<numLines:
        if(dataSet[i][-1]*sum(w*dataSet[i,0:-1]))<=0:
            w=w+dataSet[i][-1]*dataSet[i,0:-1]
            separated=False
            i=0
        else:
            i=i+1
    return w
w=train(dataset)
a=w[0][0]
b=w[0][1]
x=[-10,10]
y=[(-a/b)*-10,(-a/b)*10]
plt.plot(x,y,color="black")
idx1=where(dataset[:,2]==1)
a1=dataset[idx1,0]
a2=dataset[idx1,1]
idx2=where(dataset[:,2]==-1)
b1=dataset[idx2,0]
b2=dataset[idx2,1]
plt.scatter(a1,a2,color="red")
plt.scatter(b1,b2,color="green")
plt.show()




相关标签: 感知机