从零开始学习pytorch之线性回归
程序员文章站
2024-01-11 17:17:16
...
代码实现
代码实现较简单,且注释很完整就不在一一赘述。
import torch
import matplotlib.pyplot as plt
# 随机数种子,保证每次随机数产生是一样的
torch.manual_seed(10)
# 学习率
lr = 0.1
# 创建数据集20个点(x,y)
x = torch.rand(20,1)*10
y = 2*x + (5 + torch.randn(20, 1))
# 初始化可训练指标w和b
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
# 循环迭代1000次
for iteration in range(1000):
# 前向传播
wx = torch.mul(w, x)
y_pred = torch.add(wx, b)
# 损失函数采用MSE
loss = (0.5*(y-y_pred)**2).mean()
# 反向传播
loss.backward()
# 自更新过程
b.data.sub_(lr*b.grad)
w.data.sub_(lr*w.grad)
# 每20次迭代画一次回归图像
if iteration % 20 == 0:
# 画出散点(x,y)
plt.scatter(x.data.numpy(), y.data.numpy())
# 画线性回归曲线
plt.plot(x.data.numpy(), y_pred.data.numpy(), 'r', lw=5)
# 显示loss值,文字位置在(2,20),字体大小20,颜色为红色
plt.text(2, 20, 'Loss:%.4f'%loss.data.numpy(),fontdict={'size':20, 'color': 'red'})
plt.xlim(1.5, 10)
plt.ylim(8, 28)
plt.title('Iteration: {}\nw:{} b:{}'.format(iteration, w.data.numpy(), b.data.numpy()))
plt.pause(0.5)
# 损失值小于1则停止迭代
if loss.data.numpy()<1:
break