欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

从零开始学习pytorch之线性回归

程序员文章站 2024-01-11 17:17:16
...

代码实现

代码实现较简单,且注释很完整就不在一一赘述。

import torch
import matplotlib.pyplot as plt
# 随机数种子,保证每次随机数产生是一样的
torch.manual_seed(10)
# 学习率
lr = 0.1
# 创建数据集20个点(x,y)
x = torch.rand(20,1)*10
y = 2*x + (5 + torch.randn(20, 1))
# 初始化可训练指标w和b
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)

# 循环迭代1000次
for iteration in range(1000):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)
    # 损失函数采用MSE
    loss = (0.5*(y-y_pred)**2).mean()
    # 反向传播
    loss.backward()
    # 自更新过程
    b.data.sub_(lr*b.grad)
    w.data.sub_(lr*w.grad)
    # 每20次迭代画一次回归图像
    if iteration % 20 == 0:
        # 画出散点(x,y)
        plt.scatter(x.data.numpy(), y.data.numpy())
        # 画线性回归曲线
        plt.plot(x.data.numpy(), y_pred.data.numpy(), 'r', lw=5)
        # 显示loss值,文字位置在(2,20),字体大小20,颜色为红色
        plt.text(2, 20, 'Loss:%.4f'%loss.data.numpy(),fontdict={'size':20, 'color': 'red'})
        plt.xlim(1.5, 10)
        plt.ylim(8, 28)
        plt.title('Iteration: {}\nw:{} b:{}'.format(iteration, w.data.numpy(), b.data.numpy()))
        plt.pause(0.5)
        # 损失值小于1则停止迭代
        if loss.data.numpy()<1:
            break

运行结果

从零开始学习pytorch之线性回归