sklearn如何获得model里的参数
程序员文章站
2022-06-12 22:48:22
...
from __future__ import division
import time
import pickle
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.kernel_ridge import KernelRidge
import matplotlib.pyplot as plt
from sklearn.externals import joblib
rng = np.random.RandomState(0)
def gener(sta, end, num): # 生成y=x^2的测试集
# 随机干扰因子
sampleNo = num
mu = 0.01
sigma = 0.2
np.random.seed(0)
s = np.random.normal(mu, sigma, sampleNo)
X = np.linspace(sta, end, num) # 在返回(-1, 1)范围内的等差序列
pix = np.pi * X
Y = np.sin(pix) / pix + 0.1 * X + s
x = X.reshape(-1, 1)
y = Y
print(y.shape)
return x, y, num
X, y, num = gener(-3,3,50)
X_plot, y_plot, NUM = gener(-3,3,1000)
# #############################################################################
# Fit regression model
train_size = 40
kr = GridSearchCV(KernelRidge(kernel='rbf', gamma=0.1), cv=5,
param_grid={"alpha": [1e0, 0.1, 1e-2, 1e-3],
"gamma": np.logspace(-2, 2, 5)})
t0 = time.time()
kr.fit(X[:train_size], y[:train_size])
kr_fit = time.time() - t0
print("KRR complexity and bandwidth selected and model fitted in %.3f s"
% kr_fit)
t0 = time.time()
y_kr = kr.predict(X_plot)
kr_predict = time.time() - t0
print("KRR prediction for %d inputs in %.3f s"
% (X_plot.shape[0], kr_predict))
# #############################################################################
# Look at the results
plt.scatter(X[:100], y[:100], c='k', label='data', zorder=1,
edgecolors=(0, 0, 0))
plt.plot(X_plot, y_kr, c='g',
label='KRR (fit: %.3fs, predict: %.3fs)' % (kr_fit, kr_predict))
plt.xlabel('data')
plt.ylabel('target')
plt.legend()
plt.show()
print(dir(kr))
print(dir(kr.best_estimator_))
print(kr.best_estimator_.get_params())
coef = kr.best_estimator_.dual_coef_
def ridge(X, x, NUM, num):
x2 = x * x # 样本集的参数 逐个求平方
X2 = X * X # 测试集的参数 逐个求平方
h = 3 # 高斯核的带宽
hh = 2 * h * h
temp = np.tile(X2, num) + np.tile(np.matrix(x2).T, (NUM, 1)) - 2 * X * x.T
k = np.exp(-temp)
return k
k = ridge(X_plot, X[:40], NUM, 40)
y = np.matrix(k)*np.matrix(coef).T
plt.plot(X_plot,y,label='coef')
plt.legend()
plt.show()
print(coef)
这里用的KRR,所以只有一组系数,其他复杂模型不太了解
kr.best_estimator_.dual_coef_
上一篇: Python3 模块
下一篇: 20165218 结对编程练习-四则运算
推荐阅读
-
sklearn常用的API参数:sklearn.linear_model.LinearRegression
-
sklearn常用的API参数解析:sklearn.linear_model.LinearRegression
-
如何在postman里为类型为Edm.DateTime的OData参数指定正确格式的值 postmanSAPSAP云平台SAP Cloud PlatformABAP
-
获得的数据如何存不数据库(mysql)里,与数据库端口?请高人指点,多谢
-
sklearn如何获得model里的参数
-
thinkphp5 - thinkPHP 是不是 D() model.class 和 Controller.class 里数据库参数的优先级问题.
-
thinkphp3.2.3的Model目录里是存放什么文件的,是如何与控制器和视图文件相关联的?
-
thinkphp3.2.3的Model目录里是存放什么文件的,是如何与控制器和视图文件相关联的?
-
java 反射 如何获得子类继承的父类泛型参数
-
mysql如何获得http用get方式传过来的参数