感知机算法的 Python 实现 + 可视化
程序员文章站
2022-07-12 12:02:21
...
最入门的机器学习算法：简单的实现 + 可视化。
上代码：
from numpy import *
import matplotlib.pyplot as plt
def getLinaerSeparatableData(weight, numLines, low=-10, high=10):
    """Sample points that are linearly separable by a hyperplane through the origin.

    Every feature of each of the ``numLines`` rows is drawn uniformly from
    ``[low, high)``. Each row is labelled by the sign of its dot product with
    ``weight``: ``+1`` when the product is positive, ``-1`` otherwise. Points
    lying exactly on the hyperplane therefore get ``-1``.

    The draws come from NumPy's global legacy random state
    (``numpy.random.rand``), so ``numpy.random.seed`` makes the output
    reproducible.

    Args:
        weight: sequence of length ``d``; the normal vector of the true
            separating hyperplane.
        numLines: number of samples (rows) to generate.
        low, high: bounds of the uniform range used for every feature. The
            defaults keep the original ``[-10, 10)`` range.

    Returns:
        Float array of shape ``(numLines, d + 1)``. The first ``d`` columns
        are the features and the last column is the ``+1``/``-1`` label.
    """
    import numpy as np

    weight = np.asarray(weight, dtype=float)
    # A single (numLines, d) draw consumes the random stream in the same order
    # as numLines separate (1, d) draws, so seeded results match a row-by-row
    # generator.
    points = np.random.rand(numLines, len(weight)) * (high - low) + low
    # Element-wise product followed by a row sum gives the dot product with
    # `weight`. Explicit array methods avoid depending on which `sum` is in
    # scope.
    scores = (points * weight).sum(axis=1)
    labels = np.where(scores <= 0, -1.0, 1.0)
    return np.column_stack((points, labels))
# Draw 100 random 2-D points, labelled +1 where 3*x1 + 4*x2 > 0 and -1 otherwise.
dataset = getLinaerSeparatableData(weight=[3, 4], numLines=100)
def train(dataSet, *, max_updates=None):
    """Fit a perceptron without a bias term to labelled data.

    This uses the mistake-driven rule. Rows are scanned in order. Whenever a
    row is misclassified (``y * w.x <= 0``), the weights are updated with
    ``w += y * x`` and the scan restarts from the first row. Training stops
    once a complete pass makes no mistake.

    Args:
        dataSet: 2-D array. The last column holds the labels (+1/-1) and the
            other columns are the features.
        max_updates: optional keyword-only cap on the number of weight
            updates. The default ``None`` means no cap, as in the original.
            Set it to guard against data that is not separable by a
            hyperplane through the origin; otherwise the loop never ends for
            such data.

    Returns:
        Weight array of shape ``(1, n_features)``.

    Raises:
        ValueError: if a row has all-zero features or a zero label. Such a
            row always gives ``y * w.x == 0``, so it can never be classified.
        RuntimeError: if ``max_updates`` updates are made without separating
            the data.
    """
    import numpy as np

    data = np.asarray(dataSet, dtype=float)
    features = data[:, :-1]
    labels = data[:, -1]

    # A row with all-zero features, or with label 0, yields y * w.x == 0 for
    # every w. It is therefore always "misclassified", yet the update y * x
    # adds nothing, so the scan below would never terminate.
    degenerate = np.all(features == 0, axis=1) | (labels == 0)
    if np.any(degenerate):
        bad = int(np.flatnonzero(degenerate)[0])
        raise ValueError(
            f"row {bad} has all-zero features or a zero label; "
            "a perceptron without bias cannot classify it"
        )

    w = np.zeros((1, features.shape[1]))
    updates = 0
    i = 0
    # Restart from row 0 after every update. Reaching the end of the data
    # means a full pass found every row with y * w.x > 0.
    while i < len(data):
        x, y = features[i], labels[i]
        if y * (w * x).sum() <= 0:
            if max_updates is not None and updates >= max_updates:
                raise RuntimeError(
                    f"no separating weight found within {max_updates} updates; "
                    "the data may not be linearly separable through the origin"
                )
            w = w + y * x
            updates += 1
            i = 0
        else:
            i += 1
    return w
# Fit the perceptron to the sample drawn above.
w = train(dataset)

# The learned boundary is w0*x + w1*y = 0 (there is no bias term), i.e. the
# line y = (-w0 / w1) * x. Draw it across the default sampling range [-10, 10].
w0, w1 = w[0][0], w[0][1]
if w1 != 0:
    xs = [-10, 10]
    plt.plot(xs, [(-w0 / w1) * x for x in xs], color="black")
elif w0 != 0:
    # w1 == 0 would make the slope infinite. The boundary is the vertical
    # line x = 0.
    plt.axvline(0, color="black")
# If both weights are zero there is no boundary to draw (e.g. empty dataset).

# The label is the last column: +1 samples in red, -1 samples in green.
positive = dataset[:, -1] == 1
negative = dataset[:, -1] == -1
plt.scatter(dataset[positive, 0], dataset[positive, 1], color="red")
plt.scatter(dataset[negative, 0], dataset[negative, 1], color="green")
plt.show()