线性回归的原理及其实现
程序员文章站
2022-03-30 22:58:28
...
如何求得正规方程的w呢?
1. 不妨来随机生成一些数据来测试一下。
import numpy as np
import matplotlib.pyplot as plt
X = 2*np.random.rand(100,1)
Y = 4+3*X+np.random.randn(100,1) #y=4+3x+高斯噪声
- 数据可视化:
plt.plot(X, Y, "b.")
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$Y$", rotation=0, fontsize=18)
plt.axis([0, 2, 0, 15])
plt.show()
3. 下面将进行正规方程的内容
将使用numpy库中的线性代数的模块(np.linalg)中的求逆(inv()),内积运算(dot())
X_b = np.c_[np.ones((100,1)),X]
#print(X_b)
W_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(Y) #正规方程
print(W_best) #打印 y=4+3x+噪声,这里已经很接近了
4. 带入x=0;2 ,求预测值
X_new = np.array([[0],[2]]) #带入x=0;x=2
X_new_b = np.c_[np.ones((2,1)), X_new]
y_predict = X_new_b.dot(W_best)
y_predict #查看x=0;2对应的预测y值
5. 可视化回归方程
plt.plot(X_new, y_predict, "r-", linewidth=2, label="Predictions")
plt.plot(X, Y, "b.")
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$Y$", rotation=0, fontsize=18)
plt.legend(loc="upper left", fontsize=14)
plt.axis([0, 2, 0, 15])
plt.show()
6. 使用sklearn库中linear_model模型求解
from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()
lin_reg.fit(X,Y)
lin_reg.intercept_,lin_reg.coef_
7. lin_reg.predict(X_new) 与y_predict的输出值是一样的。
上一篇: HashMap的四种遍历
下一篇: 11.2 二分查找的原理及其实现