欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch 入门学习 实现线性回归-5

程序员文章站 2022-06-11 22:46:33
...

pytorch 入门学习实现线性回归

使用pytorch实现线性回归

import  numpy as np
import matplotlib.pyplot as plt
import torch

#part1 prepare dataset
x_data = torch.Tensor([[1.0],[2.0],[3.0]]) #矩阵的形式
y_data = torch.Tensor([[2.0],[4.0],[6.0]])

#part2 design model using class
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()   #官方写法
        self.linear = torch.nn.Linear(1,1)    #线性模型

    def forward(self,x):                     #重写callable forward
        y_pred = self.linear(x)
        return y_pred

#part3 construct loss and optimizer
#实例化,创建线性模型
model = LinearModel()
criterion = torch.nn.MSELoss(size_average=False)   #loss 不需要算均值
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)  #SGD 随机梯度下降优化器;  model.parameters:输入权重; lr 学习率

#part4 training cycle
for epoch in range(100):
    y_pred = model(x_data)
    loss = criterion(y_pred,y_data)
    print(epoch,loss.item()) #.item取出矩阵的元素值

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print('w =',model.linear.weight.item())#.item取出矩阵的元素值
print('b=',model.linear.bias.item())#.item取出矩阵的元素值
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ',y_test.data.item())

#out:
0 16.281747817993164
1 7.293664455413818
2 3.29177188873291
3 1.5095974206924438
4 0.7155891060829163
5 0.3614935278892517
6 0.20324327051639557
7 0.13218620419502258
8 0.09995455294847488
9 0.08501528948545456
10 0.07778217643499374
11 0.07398831844329834
12 0.07173360139131546
13 0.07017242908477783
14 0.06892779469490051
15 0.0678321123123169
16 0.06681052595376968
17 0.06582929939031601
18 0.06487414985895157
19 0.06393765658140182
20 0.063016876578331
21 0.06211043894290924
22 0.061217449605464935
23 0.060337554663419724
24 0.059470437467098236
25 0.05861560255289078
26 0.057773198932409286
27 0.05694276839494705
28 0.05612459406256676
29 0.05531793832778931
30 0.054522860795259476
31 0.05373939126729965
32 0.05296697840094566
33 0.05220579355955124
34 0.05145558714866638
35 0.05071603134274483
36 0.04998715966939926
37 0.049268800765275955
38 0.048560768365859985
39 0.047862716019153595
40 0.04717487469315529
41 0.046496935188770294
42 0.04582870006561279
43 0.04517001658678055
44 0.04452100023627281
45 0.043881095945835114
46 0.043250422924757004
47 0.04262881353497505
48 0.042016178369522095
49 0.041412435472011566
D:\Anaconda3\envs\pcd\lib\site-packages\torch\nn_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
  warnings.warn(warning.format(ret))
50 0.0408172532916069
51 0.040230605751276016
52 0.03965239226818085
53 0.03908245638012886
54 0.03852083534002304
55 0.03796733543276787
56 0.03742161765694618
57 0.0368838757276535
58 0.03635374456644058
59 0.035831205546855927
60 0.03531632944941521
61 0.03480881080031395
62 0.034308500587940216
63 0.03381550312042236
64 0.03332945331931114
65 0.032850489020347595
66 0.03237827494740486
67 0.03191307559609413
68 0.0314544141292572
69 0.031002411618828773
70 0.030556699261069298
71 0.03011762723326683
72 0.02968478389084339
73 0.029258109629154205
74 0.028837714344263077
75 0.028423231095075607
76 0.028014766052365303
77 0.02761211432516575
78 0.027215277776122093
79 0.02682417817413807
80 0.02643866091966629
81 0.026058759540319443
82 0.025684215128421783
83 0.02531503140926361
84 0.024951264262199402
85 0.024592645466327667
86 0.024239204823970795
87 0.02389087900519371
88 0.023547515273094177
89 0.0232091061770916
90 0.022875573486089706
91 0.02254680171608925
92 0.022222815081477165
93 0.0219033882021904
94 0.02158866822719574
95 0.021278396248817444
96 0.020972535014152527
97 0.020671162754297256
98 0.02037406899034977
99 0.020081277936697006
w = 1.905661940574646
b= 0.21445274353027344
y_pred =  7.837100505828857

Process finished with exit code 0