SVC中线性核和高斯核的效果展示
程序员文章站
2024-02-23 20:57:52
...
源码
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from sklearn import svm
plt.figure(figsize=(5, 4), dpi=140)
plt.subplot(1, 1, 1)
# 生成随机点
# 图1
from sklearn.datasets import make_blobs
X, y = make_blobs(n_samples=200,
n_features=2,
centers=[(1, 1), (2,1.5)],
random_state=4,
shuffle=False,
cluster_std=0.6) # 生成随机点
xmax, xmin = 3 , 0
ymax, ymin = 3 , 0
'''
# 图2
from sklearn.datasets import make_circles
X, y = make_circles(n_samples=100,
shuffle=True, # 打乱
noise=0.05, # 噪声0-1
random_state=4,
factor=0.3 ) # 小圆相对大圆的倍数
xmax, xmin = 1.5 , -1.5
ymax, ymin = 1.5 , -1.5
'''
'''
# 图3
from sklearn.datasets import make_moons
X, y = make_moons(n_samples = 100,
shuffle = True,
noise = 0.1, # 噪声0-1
random_state = 2)
xmax, xmin = 2.5, -1.5
ymax, ymin = 2 , -1.5
'''
plt.title('linear_SVC and rbf_SVC')
plt.xlim(xmin, xmax)
plt.ylim(ymin, ymax)
ax = plt.gca() # gca 代表当前坐标轴,即 'get current axis'
ax.spines['right'].set_color('none') # 隐藏坐标轴
ax.spines['top'].set_color('none')
############高斯核#########
rbf_svc = svm.SVC(kernel='rbf', C=1.5)
rbf_svc.fit(X, y)
temp_X,temp_Y=np.mgrid[xmin:xmax:200j, ymin:ymax:200j ] #生成网络采样点
grid_test=np.stack((temp_X.flat,temp_Y.flat) ,axis=1) #测试点
grid_hat = rbf_svc.predict(grid_test) # 预测分类值
grid_hat = grid_hat.reshape(temp_X.shape) # 使之与输入的形状相同
cm_light=matplotlib.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
plt.pcolormesh(temp_X, temp_Y, grid_hat, cmap=cm_light) # 预测值的显示
############训练点的散点图############
plt.scatter(X[y==0][:, 0], X[y==0][:, 1], marker='o')
plt.scatter(X[y==1][:, 0], X[y==1][:, 1], marker='s')
############线性核############
linear_svc = svm.SVC(kernel='linear', C=1.5)
linear_svc.fit(X, y)
w = linear_svc.coef_
b = linear_svc.intercept_
# print(w)
# print(b)
line_x = np.linspace(xmin, xmax,30)
line_y = -(w[0,0]*line_x + b)/w[0,1]
plt.plot(line_x,line_y)
plt.show()
效果1
效果2
结论
高斯核牛逼!
参考
上一篇: python3 数据库的常用操作
推荐阅读