上课笔记篇---用Pytorch实现简单的线性回归
程序员文章站
2022-05-26 20:43:50
...
@用Pytorch实现简单线性回归
用Pytorch实现简单的线性回归
之前关于深学啥的,看了一些书和论文,但真正上手,还没做过。现在项目有需求,就在微信上报了个班,开始学习Pytorch上手的东西。这里做一个记录,以备自己以后查看。
直接上代码吧
这个代码是教程上的,版权归教程所有,如果侵权,我会马上删掉。
.
// An highlighted block
lr=0.1 #学习率
x=torch.rand(20,1) * 10 #随机生成的X
y=2*x+(5+torch.rand(20,1)) #在X基础上,添加随机量,生成Y
w=torch.randn((1),requires_grad=True) #随机给个W,后面的参数是自动求导时用的
b=torch.zeros((1),requires_grad=True) #b先给的是零
for iteration in range(1000):
wx = torch.mul(w,x) # 给所有的X先乘个w
y_pred = torch.add(wx, b) #然后根据y=wx+b求出预测的y
#根据Y和预测的Y_Pred求出两个之间的平均误差
loss=(0.5 * (y-y_pred) ** 2).mean()
#再根据这个误差进行反向传播,也就是梯度下降
loss.backward()
# 根据梯度下降获得的值更新w和b
b.data.sub_(lr * b.grad)
w.data.sub_(lr * w.grad)
#这里主要是绘制图像,将散点和求出的直线绘制出来
if iteration % 20 == 0:
# 原代码中缺了下面这一句,导致画图会出问题
plt.clf()
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),y_pred.data.numpy(),"r-",lw=5)
plt.text(2,20,'Loss*%.4f'% loss.data.numpy(),fontdict={'size':20,'color':'red'})
plt.xlim(1.5,10)
plt.ylim(8,28)
plt.title("Iteration : {}\n w: {} b: {}".format(iteration,w.data.numpy(),b.data.numpy()))
plt.pause(0.5)
# 如果误差小到一定范围,则退出循环
if loss.data.numpy() < 1:
break
我是代码说明
上面这段代码,实现的是对随机生成的一堆散点实现曲线的拟合。具体的解释本来想写在这呢,后来还是觉得直接写在代码注释中好一些。
上一篇: 线性回归模型使用pytorch的简洁实现
下一篇: MARKDOWN操作