pytorch入门学习(三)-----线性回归模型
程序员文章站
2022-06-11 22:41:45
...
在运行之前需要matplotlib包,快速方便的方法可以看这个pycharm导入matplotlib
可以尝试设置不同的学习率,创建不同大小的训练数据集,以及不同的斜率和截距
import torch
import matplotlib.pyplot as plt
torch.manual_seed(10)
# 设定学习率为0.01
lr = 0.01
best_loss = float("inf")
# 创建训练数据集,
# 即制造出接近 y = 4x + 6 的数据集,加上torch.randn()制造噪声
x = torch.rand(100, 1) * 10
y = 4*x + (6 + torch.randn(100, 1))
# 构建线性回归参数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
for iteration in range(1000):
# 前向传播
wx = torch.mul(w, x)
y_pred = torch.add(wx, b)
# 计算 MSE loss
loss = (0.5 * (y - y_pred) ** 2).mean()
# 反向传播
loss.backward()
# 保留最低损失,最佳斜率及截距
current_loss = loss.item()
if current_loss < best_loss:
best_loss = current_loss
best_w = w
best_b = b
# 绘图
if loss.data.numpy() < 2:
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), y_pred.data.numpy(), 'r-', lw=5)
plt.text(2, 20, 'loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.xlim(1.5, 10)
plt.ylim(8, 28)
plt.title("Iteration: {}\nw: {} b: {}".format(iteration, w.data.numpy(), b.data.numpy()))
plt.pause(0.5)
if loss.data.numpy() < 0.6:
break
# 更新参数
b.data.sub_(lr * b.grad)
w.data.sub_(lr * w.grad)
print(best_loss)
推荐阅读
-
PyTorch搭建一维线性回归模型(二)
-
PyTorch搭建多项式回归模型(三)
-
[pysyft-002]联邦学习pysyft从入门到精通--三个节点训练一个线性分类器
-
Task1.0 学习笔记线性回归;Softmax与分类模型、多层感知机
-
动手学深度学习PyTorch-task1(线性回归;Softmax与分类模型;多层感知机)
-
Task1.0 学习笔记线性回归;Softmax与分类模型、多层感知机
-
《动手学深度学习》task1——线性回归、softmax与分类模型,多层感知机笔记
-
深度学习线性回归(pytorch)
-
PyTorch搭建多项式回归模型(三)
-
PyTorch搭建一维线性回归模型(二)