欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

线性回归的原理及其实现

程序员文章站 2022-03-30 22:58:28
...

线性回归的原理及其实现
线性回归的原理及其实现
线性回归的原理及其实现
线性回归的原理及其实现
线性回归的原理及其实现
如何求得正规方程的w呢?

1. 不妨来随机生成一些数据来测试一下。

import numpy as np
import matplotlib.pyplot as plt
 
X = 2*np.random.rand(100,1)
Y = 4+3*X+np.random.randn(100,1)   #y=4+3x+高斯噪声
  1. 数据可视化:
plt.plot(X, Y, "b.")
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$Y$", rotation=0, fontsize=18)
plt.axis([0, 2, 0, 15])
plt.show()

线性回归的原理及其实现

3. 下面将进行正规方程的内容

将使用numpy库中的线性代数的模块(np.linalg)中的求逆(inv()),内积运算(dot())

X_b = np.c_[np.ones((100,1)),X] 
#print(X_b)
W_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(Y) #正规方程
print(W_best) #打印 y=4+3x+噪声,这里已经很接近了

线性回归的原理及其实现

4. 带入x=0;2 ,求预测值

X_new = np.array([[0],[2]])  #带入x=0;x=2
X_new_b = np.c_[np.ones((2,1)), X_new]
y_predict = X_new_b.dot(W_best)
y_predict   #查看x=0;2对应的预测y值

线性回归的原理及其实现

5. 可视化回归方程

plt.plot(X_new, y_predict, "r-", linewidth=2, label="Predictions")
plt.plot(X, Y, "b.")
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$Y$", rotation=0, fontsize=18)
plt.legend(loc="upper left", fontsize=14)
plt.axis([0, 2, 0, 15])
plt.show()

线性回归的原理及其实现

6. 使用sklearn库中linear_model模型求解

from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()
lin_reg.fit(X,Y)
lin_reg.intercept_,lin_reg.coef_

线性回归的原理及其实现
线性回归的原理及其实现

7. lin_reg.predict(X_new) 与y_predict的输出值是一样的。

线性回归的原理及其实现
线性回归的原理及其实现

相关标签: 机器学习