Study Notes | PyTorch Tutorial 25 (Batch Normalization)
程序员文章站
2022-07-16 17:23:05
These notes are mainly excerpted from the "深度之眼" course and summarized here for quick reference.
PyTorch version used: 1.2
- The Batch Normalization concept
- Batch Normalization 1d/2d/3d in PyTorch
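I. The Batch Normalization concept
Batch Normalization: normalizing the data batch by batch
Batch: a batch of data, usually a mini-batch
Normalization: transforming to zero mean and unit variance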
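For each feature of a mini-batch, BN first computes the batch mean μ_B and the biased batch variance σ_B². It then standardizes, x̂ = (x − μ_B) / √(σ_B² + ε), and applies a learnable scale and shift, y = γ·x̂ + β. The sketch below writes this transform out by hand and compares it with nn.BatchNorm1d in training mode. The batch size, feature count and random data are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N, C = 8, 4                       # mini-batch of 8 samples with 4 features (arbitrary sizes)
x = torch.randn(N, C) * 3 + 5     # deliberately not zero-mean / unit-variance

bn = nn.BatchNorm1d(C)            # gamma (bn.weight) starts at 1, beta (bn.bias) at 0
bn.train()                        # training mode: normalize with the current batch statistics

with torch.no_grad():
    # the BN transform written out by hand, one statistic per feature (column)
    mu = x.mean(dim=0)                            # batch mean, shape (C,)
    var = x.var(dim=0, unbiased=False)            # biased batch variance, shape (C,)
    x_hat = (x - mu) / torch.sqrt(var + bn.eps)   # standardize (eps = 1e-5 by default)
    y_manual = bn.weight * x_hat + bn.bias        # scale by gamma, shift by beta

    y_torch = bn(x)

print(torch.allclose(y_manual, y_torch, atol=1e-5))   # True
print(y_torch.mean(dim=0))                             # close to 0 for every feature
print(y_torch.var(dim=0, unbiased=False))              # close to 1 for every feature
```

Because γ and β are learnable, the network can undo the standardization when that helps. "Zero mean, unit variance" therefore describes x̂, not necessarily the final output y.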
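Advantages:
- 1. Larger learning rates can be used, which speeds up convergence
- 2. Weight initialization no longer has to be designed carefully
- 3. Dropout can be dropped or reduced
- 4. L2 regularization (weight decay) can be dropped or reduced
- 5. LRN (local response normalization) is no longer needed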
Reference: "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Ioffe and Szegedy, 2015)
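Test code: a 100-layer MLP (256 neurons per layer, bias-free linear layers) that prints the standard deviation of the activations after every layer: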
import random
import torch
import numpy as np
import torch.nn as nn

try:
    from tools.common_tools import set_seed  # helper from the course repository
except ImportError:
    def set_seed(seed=1):
        # stand-in (assumption): seed the Python, NumPy and PyTorch RNGs
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

set_seed(1)  # set the random seed


class MLP(nn.Module):
    def __init__(self, neural_num, layers=100):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(neural_num) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):
        for (i, linear), bn in zip(enumerate(self.linears), self.bns):
            x = linear(x)
            # x = bn(x)  # uncomment to put BN between the linear layer and the ReLU
            x = torch.relu(x)

            # stop as soon as the activations overflow to NaN
            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break

            print("layers:{}, std:{}".format(i, x.std().item()))

        return x

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # method 1: standard normal weights, mean=0, std=1
                nn.init.normal_(m.weight.data, std=1)

                # method 2: Kaiming (He) normal initialization
                # nn.init.kaiming_normal_(m.weight.data)


neural_nums = 256
layer_nums = 100
batch_size = 16

net = MLP(neural_nums, layer_nums)
# net.initialize()

inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1

output = net(inputs)
print(output)
Output:
layers:0, std:0.3342404067516327
layers:1, std:0.13787388801574707
layers:2, std:0.05783054977655411
layers:3, std:0.02498556487262249
layers:4, std:0.009679116308689117
layers:5, std:0.0040797945111989975
layers:6, std:0.0016723505686968565
layers:7, std:0.000768698868341744
......
layers:93, std:7.51512610937515e-38
layers:94, std:2.6169094678434883e-38
layers:95, std:1.1516209894049713e-38
layers:96, std:4.344910860036386e-39
layers:97, std:1.5943525511579185e-39
layers:98, std:5.721221370145363e-40
layers:99, std:2.4877251637158477e-40
tensor([[0.0000e+00, 2.1158e-41, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,
0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,
0.0000e+00],
[0.0000e+00, 5.1800e-41, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,
0.0000e+00],
...,
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,
0.0000e+00],
[0.0000e+00, 5.8066e-41, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,
0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,
0.0000e+00]], grad_fn=<ReluBackward0>)
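Without net.initialize(), the layers keep PyTorch's default nn.Linear initialization. The standard deviation of the activations then shrinks by a roughly constant factor at every layer and is almost zero by the 100th layer: the activations vanish.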
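A back-of-the-envelope variance argument explains the constant shrink factor. This is the reasoning behind Kaiming initialization, and it assumes i.i.d. zero-mean weights and roughly Gaussian pre-activations. A bias-free Linear(n, n) followed by a ReLU multiplies the activation std by about √(n·Var(w)/2). For z ~ N(0, s²), std(ReLU(z)) = s·√(1/2 − 1/(2π)) ≈ 0.584·s. The illustrative helper predict below (not from the course) plugs in n = 256 for the three weight settings used in this section. The default-init variance 1/(3n) comes from nn.Linear's default uniform init on (−1/√n, 1/√n).

```python
import math

# std of ReLU(z) for z ~ N(0, s^2) equals s * sqrt(1/2 - 1/(2*pi))
RELU_STD_FACTOR = math.sqrt(0.5 - 1.0 / (2.0 * math.pi))


def predict(n, var_w):
    """Rough prediction for stacked bias-free Linear(n, n) + ReLU layers fed with N(0, 1) input.

    Returns (std after layer 0, multiplicative change of the std per additional layer).
    """
    first = math.sqrt(n * var_w) * RELU_STD_FACTOR  # layer-0 pre-activation std is sqrt(n * var_w)
    ratio = math.sqrt(n * var_w / 2.0)              # ReLU keeps half of the second moment
    return first, ratio


n = 256  # neural_nums in the MLP above
settings = {
    "default nn.Linear, U(-1/sqrt(n), 1/sqrt(n))": 1.0 / (3 * n),
    "nn.init.normal_(std=1)": 1.0,
    "nn.init.kaiming_normal_, Var(w) = 2/n": 2.0 / n,
}
for name, var_w in settings.items():
    first, ratio = predict(n, var_w)
    print("{:45s} layer0 std ~ {:.3f}, x{:.3f} per layer".format(name, first, ratio))
```

For the default init this predicts a layer-0 std of about 0.337 and a factor of about 0.408 per layer. The log above shows 0.334, then 0.138 (a ratio of about 0.41). For std=1 weights the factor is about 11.3 (explosion). For Kaiming it is about 1 (stable). These are approximations, so small drifts over 100 layers are expected.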
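Now apply initialization by uncommenting net.initialize() (method 1: weights drawn from a standard normal distribution, std=1):
Output: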
layers:0, std:9.35224723815918
layers:1, std:112.47123718261719
layers:2, std:1322.805419921875
layers:3, std:14569.419921875
layers:4, std:154672.703125
layers:5, std:1834037.125
layers:6, std:18807968.0
layers:7, std:209552880.0
......
layers:28, std:3.221297392084588e+30
layers:29, std:3.530939139138446e+31
layers:30, std:4.525336236359181e+32
layers:31, std:4.714992054712809e+33
layers:32, std:5.369568386632447e+34
layers:33, std:6.712290740934239e+35
layers:34, std:7.451081630611702e+36
output is nan in 35 layers
tensor([[3.2625e+36, 0.0000e+00, 7.2931e+37, ..., 0.0000e+00, 0.0000e+00,
2.5465e+38],
[3.9236e+36, 0.0000e+00, 7.5033e+37, ..., 0.0000e+00, 0.0000e+00,
2.1274e+38],
[0.0000e+00, 0.0000e+00, 4.4931e+37, ..., 0.0000e+00, 0.0000e+00,
1.7016e+38],
...,
[0.0000e+00, 0.0000e+00, 2.4222e+37, ..., 0.0000e+00, 0.0000e+00,
2.5295e+38],
[4.7380e+37, 0.0000e+00, 2.1579e+37, ..., 0.0000e+00, 0.0000e+00,
2.6028e+38],
[0.0000e+00, 0.0000e+00, 6.0877e+37, ..., 0.0000e+00, 0.0000e+00,
2.1695e+38]], grad_fn=<ReluBackward0>)
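This time the activations explode. The std grows by a factor of roughly 10 to 12 per layer, and the output becomes NaN at layer 35.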
Switching to Kaiming initialization (method 2: nn.init.kaiming_normal_(m.weight.data)):
Output:
layers:0, std:0.826629638671875
layers:1, std:0.878681480884552
layers:2, std:0.9134420156478882
layers:3, std:0.8892467617988586
layers:4, std:0.8344276547431946
layers:5, std:0.87453693151474
layers:6, std:0.792696475982666
layers:7, std:0.7806451916694641
......
layers:92, std:0.6094536185264587
layers:93, std:0.6019036173820496
layers:94, std:0.595414936542511
layers:95, std:0.6624482870101929
layers:96, std:0.6377813220024109
layers:97, std:0.6079217195510864
layers:98, std:0.6579239368438721
layers:99, std:0.6668398976325989
tensor([[0.0000, 1.3437, 0.0000, ..., 0.0000, 0.6444, 1.1867],
[0.0000, 0.9757, 0.0000, ..., 0.0000, 0.4645, 0.8594],
[0.0000, 1.0023, 0.0000, ..., 0.0000, 0.5147, 0.9196],
...,
[0.0000, 1.2873, 0.0000, ..., 0.0000, 0.6454, 1.1411],
[0.0000, 1.3588, 0.0000, ..., 0.0000, 0.6749, 1.2437],
[0.0000, 1.1807, 0.0000, ..., 0.0000, 0.5668, 1.0600]],
grad_fn=<ReluBackward0>)
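The std now stays at the same order of magnitude across all 100 layers, although it still fluctuates (between about 0.6 and 0.9). Next, add the BN layer by uncommenting x = bn(x):
Output: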
layers:0, std:0.5872595906257629
layers:1, std:0.579325795173645
layers:2, std:0.5757012367248535
layers:3, std:0.5840616822242737
layers:4, std:0.5781518220901489
layers:5, std:0.5856173634529114
layers:6, std:0.5862171053886414
......
layers:95, std:0.5735476016998291
layers:96, std:0.5807774662971497
layers:97, std:0.5868753790855408
layers:98, std:0.5801646113395691
layers:99, std:0.5738694667816162
tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 1.4841, 0.0000],
[1.5034, 0.0000, 0.2277, ..., 0.3768, 0.0000, 0.0000],
[0.9003, 0.0000, 1.7231, ..., 0.0000, 0.0000, 1.1034],
...,
[0.0000, 0.0000, 0.6059, ..., 0.0000, 0.0000, 0.0000],
[0.7283, 0.6607, 0.4622, ..., 0.0000, 0.0000, 0.0000],
[0.0331, 0.0000, 1.0855, ..., 1.2032, 0.0000, 0.3746]],
grad_fn=<ReluBackward0>)
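The scale of the data is well preserved, with the std close to 0.58 at every layer. Now drop the initialization by commenting it out (# net.initialize()):
Output: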
layers:0, std:0.5751240849494934
layers:1, std:0.5803307890892029
layers:2, std:0.5825020670890808
layers:3, std:0.5823132395744324
......
layers:97, std:0.5814812183380127
layers:98, std:0.5802980661392212
layers:99, std:0.5824452638626099
tensor([[2.4655, 0.3893, 0.0000, ..., 1.9130, 0.7964, 0.7588],
[0.3542, 0.1579, 2.3155, ..., 0.0500, 0.2595, 0.0000],
[0.0000, 0.0000, 0.2838, ..., 0.0000, 0.9119, 0.2732],
...,
[0.0000, 1.5330, 0.0000, ..., 0.1120, 0.0000, 1.9477],
[0.0000, 0.0000, 2.0451, ..., 0.0000, 0.0000, 0.0000],
[0.5085, 0.8023, 0.3493, ..., 0.2117, 0.0000, 0.0000]],
grad_fn=<ReluBackward0>)
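The scale is still well preserved. With BN, the network no longer depends on a carefully designed weight initialization.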
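One way to see why the initialization stops mattering: in training mode, BN's output is unchanged (up to the small effect of eps) when the weights of the preceding linear layer are multiplied by a positive constant. The constant cancels between x − μ_B and σ_B. The sketch below checks this for an arbitrary scale factor of 100. The layer size matches the MLP above, and the data is random.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(16, 256)                     # batch of 16, 256 features (as in the MLP)
linear = nn.Linear(256, 256, bias=False)
bn = nn.BatchNorm1d(256)                     # modules start in training mode

with torch.no_grad():
    out_small = bn(linear(x))                # weights at their default scale
    linear.weight.mul_(100.0)                # blow the weights up by 100x
    out_large = bn(linear(x))                # pre-activations are now 100x larger

print(torch.allclose(out_small, out_large, atol=1e-3))   # True: BN removes the scale
```

Without BN, the same 100x weight scaling would carry straight through to the ReLU and every later layer.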
Note: the BN layer is placed before the activation function.
Building a LeNet with BN layers:
import torch.nn as nn
import torch.nn.functional as F


class LeNet_bn(nn.Module):
    def __init__(self, classes):
        super(LeNet_bn, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(num_features=6)

        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(num_features=16)

        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(num_features=120)

        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        # conv -> BN -> ReLU -> max pool
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)

        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)

        # flatten (N, 16, 5, 5) -> (N, 400); this size implies 32x32 input images
        out = out.view(out.size(0), -1)

        # fully connected -> BN -> ReLU
        out = self.fc1(out)
        out = self.bn3(out)
        out = F.relu(out)

        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # gamma = 1, beta = 0
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, 0, 1)
                m.bias.data.zero_()
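In LeNet_bn, bn1 and bn2 follow convolutions and receive 4D tensors (N, C, H, W), so they are BatchNorm2d with one statistic per channel. bn3 follows fc1 after the flatten and receives a 2D tensor (N, 120), so it is BatchNorm1d. The sketch below traces those shapes with a dummy batch, using fresh, untrained layers built inline just for the trace. It assumes 32×32 RGB inputs, which is what 16 * 5 * 5 in fc1 implies, and an arbitrary batch size of 4.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 3, 32, 32)     # dummy batch: N=4 RGB images of 32x32 (assumed input size)

# conv1 -> bn1 (BatchNorm2d) -> ReLU -> pool
h1 = F.max_pool2d(F.relu(nn.BatchNorm2d(6)(nn.Conv2d(3, 6, 5)(x))), 2)
# conv2 -> bn2 (BatchNorm2d) -> ReLU -> pool
h2 = F.max_pool2d(F.relu(nn.BatchNorm2d(16)(nn.Conv2d(6, 16, 5)(h1))), 2)
# flatten -> fc1 -> bn3 (BatchNorm1d) -> ReLU
flat = h2.view(h2.size(0), -1)
h3 = F.relu(nn.BatchNorm1d(120)(nn.Linear(16 * 5 * 5, 120)(flat)))

print(h1.shape)    # torch.Size([4, 6, 14, 14]): BN2d statistics over N, H, W per channel
print(h2.shape)    # torch.Size([4, 16, 5, 5])
print(flat.shape)  # torch.Size([4, 400])
print(h3.shape)    # torch.Size([4, 120]): BN1d statistics over N per feature

# a 2D fully connected output does not pass BatchNorm2d's 4D input check
try:
    nn.BatchNorm2d(120)(h3)
    wrong_dim_rejected = False
except ValueError as err:
    wrong_dim_rejected = True
    print("ValueError:", err)
```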
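Next, test the effect of BN layers inside a network.
The complete training code is in: Study Notes | PyTorch Tutorial 05 (Dataloader and Dataset)
First, train the plain LeNet, without BN layers and without initialization: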
# ============================ step 2/5 model ============================
#net = LeNet_bn(classes=2)
net = LeNet(classes=2)
# net.initialize_weights()
Output:
Training:Epoch[000/010] Iteration[010/010] Loss: 0.6966 Acc:50.00%
Valid: Epoch[000/010] Iteration[002/002] Loss: 1.3483 Acc:50.00%
Training:Epoch[001/010] Iteration[010/010] Loss: 0.6888 Acc:53.75%
Valid: Epoch[001/010] Iteration[002/002] Loss: 1.3469 Acc:53.75%
Training:Epoch[002/010] Iteration[010/010] Loss: 0.6822 Acc:60.62%
Valid: Epoch[002/010] Iteration[002/002] Loss: 1.3270 Acc:60.62%
Training:Epoch[003/010] Iteration[010/010] Loss: 0.6739 Acc:81.25%
Valid: Epoch[003/010] Iteration[002/002] Loss: 1.2961 Acc:81.25%
Training:Epoch[004/010] Iteration[010/010] Loss: 0.6466 Acc:83.75%
Valid: Epoch[004/010] Iteration[002/002] Loss: 1.1401 Acc:83.75%
Training:Epoch[005/010] Iteration[010/010] Loss: 0.5422 Acc:95.62%
Valid: Epoch[005/010] Iteration[002/002] Loss: 0.6329 Acc:95.62%
Training:Epoch[006/010] Iteration[010/010] Loss: 0.2208 Acc:96.88%
Valid: Epoch[006/010] Iteration[002/002] Loss: 0.0163 Acc:96.88%
Training:Epoch[007/010] Iteration[010/010] Loss: 0.1321 Acc:95.62%
Valid: Epoch[007/010] Iteration[002/002] Loss: 0.0006 Acc:95.62%
Training:Epoch[008/010] Iteration[010/010] Loss: 0.2649 Acc:93.75%
Valid: Epoch[008/010] Iteration[002/002] Loss: 1.2047 Acc:93.75%
Training:Epoch[009/010] Iteration[010/010] Loss: 0.4774 Acc:87.50%
Valid: Epoch[009/010] Iteration[002/002] Loss: 0.4023 Acc:87.50%
Next, enable initialization by uncommenting net.initialize_weights():
Output:
Training:Epoch[000/010] Iteration[010/010] Loss: 0.6846 Acc:53.75%
Valid: Epoch[000/010] Iteration[002/002] Loss: 0.9805 Acc:53.75%
Training:Epoch[001/010] Iteration[010/010] Loss: 0.4099 Acc:85.00%
Valid: Epoch[001/010] Iteration[002/002] Loss: 0.0829 Acc:85.00%
Training:Epoch[002/010] Iteration[010/010] Loss: 0.1470 Acc:94.38%
Valid: Epoch[002/010] Iteration[002/002] Loss: 0.0035 Acc:94.38%
Training:Epoch[003/010] Iteration[010/010] Loss: 0.4276 Acc:88.12%
Valid: Epoch[003/010] Iteration[002/002] Loss: 0.2250 Acc:88.12%
Training:Epoch[004/010] Iteration[010/010] Loss: 0.3169 Acc:87.50%
Valid: Epoch[004/010] Iteration[002/002] Loss: 0.1232 Acc:87.50%
Training:Epoch[005/010] Iteration[010/010] Loss: 0.2026 Acc:91.88%
Valid: Epoch[005/010] Iteration[002/002] Loss: 0.0132 Acc:91.88%
Training:Epoch[006/010] Iteration[010/010] Loss: 0.1064 Acc:95.62%
Valid: Epoch[006/010] Iteration[002/002] Loss: 0.0002 Acc:95.62%
Training:Epoch[007/010] Iteration[010/010] Loss: 0.0482 Acc:99.38%
Valid: Epoch[007/010] Iteration[002/002] Loss: 0.0006 Acc:99.38%
Training:Epoch[008/010] Iteration[010/010] Loss: 0.0069 Acc:100.00%
Valid: Epoch[008/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
Training:Epoch[009/010] Iteration[010/010] Loss: 0.0133 Acc:99.38%
Valid: Epoch[009/010] Iteration[002/002] Loss: 0.0000 Acc:99.38%
Next, train the LeNet_bn network instead: net = LeNet_bn(classes=2)
Output:
Training:Epoch[000/010] Iteration[010/010] Loss: 0.6666 Acc:60.00%
Valid: Epoch[000/010] Iteration[002/002] Loss: 1.2814 Acc:60.00%
Training:Epoch[001/010] Iteration[010/010] Loss: 0.4274 Acc:93.12%
Valid: Epoch[001/010] Iteration[002/002] Loss: 0.4916 Acc:93.12%
Training:Epoch[002/010] Iteration[010/010] Loss: 0.1601 Acc:98.75%
Valid: Epoch[002/010] Iteration[002/002] Loss: 0.0878 Acc:98.75%
Training:Epoch[003/010] Iteration[010/010] Loss: 0.0688 Acc:100.00%
Valid: Epoch[003/010] Iteration[002/002] Loss: 0.0104 Acc:100.00%
Training:Epoch[004/010] Iteration[010/010] Loss: 0.0406 Acc:98.75%
Valid: Epoch[004/010] Iteration[002/002] Loss: 0.0109 Acc:98.75%
Training:Epoch[005/010] Iteration[010/010] Loss: 0.0895 Acc:95.62%
Valid: Epoch[005/010] Iteration[002/002] Loss: 0.0061 Acc:95.62%
Training:Epoch[006/010] Iteration[010/010] Loss: 0.0765 Acc:95.62%
Valid: Epoch[006/010] Iteration[002/002] Loss: 0.0675 Acc:95.62%
Training:Epoch[007/010] Iteration[010/010] Loss: 0.0370 Acc:98.75%
Valid: Epoch[007/010] Iteration[002/002] Loss: 0.0069 Acc:98.75%
Training:Epoch[008/010] Iteration[010/010] Loss: 0.0144 Acc:100.00%
Valid: Epoch[008/010] Iteration[002/002] Loss: 0.0028 Acc:100.00%
Training:Epoch[009/010] Iteration[010/010] Loss: 0.0365 Acc:98.75%
Valid: Epoch[009/010] Iteration[002/002] Loss: 0.0015 Acc:98.75%
II. Batch Normalization 1d/2d/3d in PyTorch
1. _BatchNorm, the common base class of:
- nn.BatchNorm1d
- nn.BatchNorm2d
- nn.BatchNorm3d
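Parameters:
- num_features: number of features per sample (the most important parameter)
- eps: small constant added to the variance in the denominator for numerical stability
- momentum: weight of the current batch statistic in the exponential moving average that estimates the mean/var
- affine: whether to apply the learnable affine transform (gamma, beta)
- track_running_stats: whether to keep running estimates of the mean/var
  - True (default): in training mode (model.train()) the current batch statistics are used for normalization, and the running estimates are updated by exponential moving average; in test mode (model.eval()) the running estimates are used
  - False: no running estimates are kept, and the current batch statistics are used in both training and testing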
Main attributes:
- running_mean: running estimate of the mean
- running_var: running estimate of the variance
- weight: gamma in the affine transform
- bias: beta in the affine transform
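The sketch below shows how these attributes are used in the two modes. In train() mode the layer normalizes with the current batch statistics and updates running_mean/running_var. In eval() mode it normalizes with the stored estimates, computing (x − running_mean) / √(running_var + eps) · weight + bias. With track_running_stats=False no estimates are kept (running_mean is None), so batch statistics are used in both modes. The data and sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 3) * 2 + 1             # 8 samples, 3 features (arbitrary data)

bn = nn.BatchNorm1d(3, momentum=0.1)      # running_mean starts at 0, running_var at 1

bn.train()
with torch.no_grad():
    y_train = bn(x)                        # batch statistics; running stats get updated
print(bn.running_mean)                     # 0.9 * 0 + 0.1 * batch mean
print(bn.running_var)                      # 0.9 * 1 + 0.1 * unbiased batch variance

bn.eval()
with torch.no_grad():
    y_eval = bn(x)                         # running statistics; nothing is updated
    y_expected = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps) * bn.weight + bn.bias
print(torch.allclose(y_eval, y_expected, atol=1e-5))   # True
print(torch.allclose(y_eval, y_train, atol=1e-3))      # False: different statistics

# track_running_stats=False: no running estimates, batch statistics even in eval mode
bn_nostats = nn.BatchNorm1d(3, track_running_stats=False).eval()
print(bn_nostats.running_mean)             # None
with torch.no_grad():
    y_nostats = bn_nostats(x)
print(torch.allclose(y_nostats, y_train, atol=1e-5))   # True
```

This is why a network containing BN layers should be switched with model.train() before training and model.eval() before validation or inference.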
1.nn.BatchNorm1d
Test code:
import random
import torch
import numpy as np
import torch.nn as nn

try:
    from tools.common_tools import set_seed  # helper from the course repository
except ImportError:
    def set_seed(seed=1):
        # stand-in (assumption): seed the Python, NumPy and PyTorch RNGs
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

set_seed(1)  # set the random seed

# ======================================== nn.BatchNorm1d
flag = 1
# flag = 0
if flag:

    batch_size = 3
    num_features = 5
    momentum = 0.3

    features_shape = (1)  # note: (1) is the int 1, so each feature is a length-1 vector

    feature_map = torch.ones(features_shape)                                                # 1D
    feature_maps = torch.stack([feature_map*(i+1) for i in range(num_features)], dim=0)     # 2D
    feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)         # 3D
    print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))

    bn = nn.BatchNorm1d(num_features=num_features, momentum=momentum)

    # hand-tracked running statistics of feature 2, starting from PyTorch's initial values
    running_mean, running_var = 0, 1

    for i in range(2):
        outputs = bn(feature_maps_bs)

        print("\niteration:{}, running mean: {} ".format(i, bn.running_mean))
        print("iteration:{}, running var:{} ".format(i, bn.running_var))

        # every value of feature 2 equals 2: batch mean 2, batch variance 0
        mean_t, var_t = 2, 0

        running_mean = (1 - momentum) * running_mean + momentum * mean_t
        running_var = (1 - momentum) * running_var + momentum * var_t

        print("iteration:{}, running mean of feature 2: {} ".format(i, running_mean))
        print("iteration:{}, running var of feature 2:{}".format(i, running_var))
Output:
input data:
tensor([[[1.],
[2.],
[3.],
[4.],
[5.]],
[[1.],
[2.],
[3.],
[4.],
[5.]],
[[1.],
[2.],
[3.],
[4.],
[5.]]]) shape is torch.Size([3, 5, 1])
iteration:0, running mean: tensor([0.3000, 0.6000, 0.9000, 1.2000, 1.5000])
iteration:0, running var:tensor([0.7000, 0.7000, 0.7000, 0.7000, 0.7000])
iteration:0, running mean of feature 2: 0.6
iteration:0, running var of feature 2:0.7
iteration:1, running mean: tensor([0.5100, 1.0200, 1.5300, 2.0400, 2.5500])
iteration:1, running var:tensor([0.4900, 0.4900, 0.4900, 0.4900, 0.4900])
iteration:1, running mean of feature 2: 1.02
iteration:1, running var of feature 2:0.48999999999999994
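In the example above, every value of a feature is identical, so the batch variance is 0. That hides a detail: the nn.BatchNorm1d documentation states that the output is normalized with the biased batch variance, while running_var is updated with the unbiased one. Also, for a 3D input (N, C, L), each feature's statistics are taken over both the batch axis and the length axis. The sketch below checks one update with momentum=0.3 on random data. The (3, 5, 4) shape is an arbitrary choice with L > 1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
momentum = 0.3
x = torch.randn(3, 5, 4)                 # (N, C, L): batch 3, 5 features, length 4

bn = nn.BatchNorm1d(num_features=5, momentum=momentum)
with torch.no_grad():
    y = bn(x)                            # training mode: one update of the running stats

# per-feature statistics over the N and L axes: move C first and flatten the rest
flat = x.transpose(0, 1).reshape(5, -1)        # shape (C, N*L) = (5, 12)
mean_t = flat.mean(dim=1)
var_biased = flat.var(dim=1, unbiased=False)   # used to normalize the batch
var_unbiased = flat.var(dim=1, unbiased=True)  # used to update running_var

# running = (1 - momentum) * running + momentum * batch statistic, from mean 0 / var 1
expected_mean = (1 - momentum) * 0 + momentum * mean_t
expected_var = (1 - momentum) * 1 + momentum * var_unbiased
print(torch.allclose(bn.running_mean, expected_mean, atol=1e-5))   # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-5))     # True

# weight = 1 and bias = 0 at initialization, so the affine step is left out here
y_manual = (x - mean_t.view(1, 5, 1)) / torch.sqrt(var_biased.view(1, 5, 1) + bn.eps)
print(torch.allclose(y, y_manual, atol=1e-5))                       # True
```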
2.nn.BatchNorm2d
Test code:
# ======================================== nn.BatchNorm2d
# (continuation of the script above: torch and nn are already imported)
flag = 1
# flag = 0
if flag:

    batch_size = 3
    num_features = 6
    momentum = 0.3

    features_shape = (2, 2)

    feature_map = torch.ones(features_shape)                                                # 2D
    feature_maps = torch.stack([feature_map*(i+1) for i in range(num_features)], dim=0)     # 3D
    feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)         # 4D

    print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))

    bn = nn.BatchNorm2d(num_features=num_features, momentum=momentum)

    for i in range(2):
        outputs = bn(feature_maps_bs)

        # one entry per channel: the 2x2 spatial size does not affect these shapes
        print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))
        print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))
        print("iter:{}, weight.shape: {}".format(i, bn.weight.shape))
        print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))
Output:
input data:
tensor([[[[1., 1.],
[1., 1.]],
[[2., 2.],
[2., 2.]],
[[3., 3.],
[3., 3.]],
[[4., 4.],
[4., 4.]],
[[5., 5.],
[5., 5.]],
[[6., 6.],
[6., 6.]]],
[[[1., 1.],
[1., 1.]],
[[2., 2.],
[2., 2.]],
[[3., 3.],
[3., 3.]],
[[4., 4.],
[4., 4.]],
[[5., 5.],
[5., 5.]],
[[6., 6.],
[6., 6.]]],
[[[1., 1.],
[1., 1.]],
[[2., 2.],
[2., 2.]],
[[3., 3.],
[3., 3.]],
[[4., 4.],
[4., 4.]],
[[5., 5.],
[5., 5.]],
[[6., 6.],
[6., 6.]]]]) shape is torch.Size([3, 6, 2, 2])
iter:0, running_mean.shape: torch.Size([6])
iter:0, running_var.shape: torch.Size([6])
iter:0, weight.shape: torch.Size([6])
iter:0, bias.shape: torch.Size([6])
iter:1, running_mean.shape: torch.Size([6])
iter:1, running_var.shape: torch.Size([6])
iter:1, weight.shape: torch.Size([6])
iter:1, bias.shape: torch.Size([6])
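The shapes above show one running_mean, running_var, weight and bias entry per channel. BatchNorm2d computes each channel's statistics over the batch axis and both spatial axes, that is, over N×H×W values (3×2×2 = 12 here). Since every channel in that example is constant, its normalized output is all zeros. The sketch below instead uses random data and non-trivial gamma/beta to check the per-channel computation by hand. The parameter values are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, H, W = 3, 6, 2, 2                  # same layout as the example above
x = torch.randn(N, C, H, W)

bn = nn.BatchNorm2d(C)
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)          # arbitrary gamma per channel
    bn.bias.uniform_(-1.0, 1.0)           # arbitrary beta per channel
    y = bn(x)                             # training mode: batch statistics

    # channel statistics over (N, H, W): put C first, flatten the rest -> (C, N*H*W)
    flat = x.transpose(0, 1).reshape(C, -1)
    mean = flat.mean(dim=1).view(1, C, 1, 1)                 # reshape to broadcast over (N, C, H, W)
    var = flat.var(dim=1, unbiased=False).view(1, C, 1, 1)   # biased variance
    y_manual = (x - mean) / torch.sqrt(var + bn.eps) * bn.weight.view(1, C, 1, 1) + bn.bias.view(1, C, 1, 1)

print(flat.shape)                                # torch.Size([6, 12])
print(torch.allclose(y, y_manual, atol=1e-5))    # True
```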
3.nn.BatchNorm3d
Test code:
# ======================================== nn.BatchNorm3d
# (continuation of the script above: torch and nn are already imported)
flag = 1
# flag = 0
if flag:

    batch_size = 3
    num_features = 4
    momentum = 0.3

    features_shape = (2, 2, 3)

    feature = torch.ones(features_shape)                                                    # 3D
    feature_map = torch.stack([feature * (i + 1) for i in range(num_features)], dim=0)      # 4D
    feature_maps = torch.stack([feature_map for i in range(batch_size)], dim=0)             # 5D

    print("input data:\n{} shape is {}".format(feature_maps, feature_maps.shape))

    bn = nn.BatchNorm3d(num_features=num_features, momentum=momentum)

    for i in range(2):
        outputs = bn(feature_maps)

        # one entry per channel: the 2x2x3 volume does not affect these shapes
        print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))
        print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))
        print("iter:{}, weight.shape: {}".format(i, bn.weight.shape))
        print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))
Output:
input data:
tensor([[[[[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.]]],
[[[2., 2., 2.],
[2., 2., 2.]],
[[2., 2., 2.],
[2., 2., 2.]]],
[[[3., 3., 3.],
[3., 3., 3.]],
[[3., 3., 3.],
[3., 3., 3.]]],
[[[4., 4., 4.],
[4., 4., 4.]],
[[4., 4., 4.],
[4., 4., 4.]]]],
[[[[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.]]],
[[[2., 2., 2.],
[2., 2., 2.]],
[[2., 2., 2.],
[2., 2., 2.]]],
[[[3., 3., 3.],
[3., 3., 3.]],
[[3., 3., 3.],
[3., 3., 3.]]],
[[[4., 4., 4.],
[4., 4., 4.]],
[[4., 4., 4.],
[4., 4., 4.]]]],
[[[[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.]]],
[[[2., 2., 2.],
[2., 2., 2.]],
[[2., 2., 2.],
[2., 2., 2.]]],
[[[3., 3., 3.],
[3., 3., 3.]],
[[3., 3., 3.],
[3., 3., 3.]]],
[[[4., 4., 4.],
[4., 4., 4.]],
[[4., 4., 4.],
[4., 4., 4.]]]]]) shape is torch.Size([3, 4, 2, 2, 3])
iter:0, running_mean.shape: torch.Size([4])
iter:0, running_var.shape: torch.Size([4])
iter:0, weight.shape: torch.Size([4])
iter:0, bias.shape: torch.Size([4])
iter:1, running_mean.shape: torch.Size([4])
iter:1, running_var.shape: torch.Size([4])
iter:1, weight.shape: torch.Size([4])
iter:1, bias.shape: torch.Size([4])
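Taken together, nn.BatchNorm1d, nn.BatchNorm2d and nn.BatchNorm3d compute the same per-channel normalization. They differ only in the input ranks they accept: (N, C) or (N, C, L) for 1d, (N, C, H, W) for 2d, and (N, C, D, H, W) for 3d. The sketch below uses a hypothetical helper, manual_bn (not a PyTorch function), that reduces over every axis except the channel axis. It checks the helper against all three classes and shows that a tensor of the wrong rank is rejected with a ValueError.

```python
import torch
import torch.nn as nn


def manual_bn(x, eps=1e-5):
    """Hypothetical helper: training-mode BN with gamma=1, beta=0, statistics over all axes except dim 1."""
    dims = tuple(d for d in range(x.dim()) if d != 1)        # every axis except the channel axis
    mean = x.mean(dim=dims, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)     # biased variance
    return (x - mean) / torch.sqrt(var + eps)


torch.manual_seed(0)
cases = [
    (nn.BatchNorm1d(5), torch.randn(3, 5)),             # (N, C)
    (nn.BatchNorm1d(5), torch.randn(3, 5, 1)),          # (N, C, L), as in the BatchNorm1d example
    (nn.BatchNorm2d(6), torch.randn(3, 6, 2, 2)),       # (N, C, H, W)
    (nn.BatchNorm3d(4), torch.randn(3, 4, 2, 2, 3)),    # (N, C, D, H, W)
]
results = []
with torch.no_grad():
    for bn, x in cases:
        results.append(torch.allclose(bn(x), manual_bn(x, bn.eps), atol=1e-5))
print(results)   # [True, True, True, True]

# the classes differ in the input rank they accept
try:
    nn.BatchNorm3d(4)(torch.randn(3, 4, 2, 2))          # 4D input to a 3d layer
except ValueError as err:
    print("ValueError:", err)
```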