pytorch之——Batch Normalization
程序员文章站
2022-05-18 10:04:26
...
pytorch之——Batch Normalization
文章目录
一、Batch Normalization概念
1.Batch Normalization:批标准化
批:一批数据,通常为mini_batch
标准化:0均值,1方差
优点:
(1)可以用更大的学习率,加速模型收敛
(2)可以不用精心设计权值初始化(因为权值初始化也是为了缩放数据尺度,BN有着同样的作用)
(3)可以不用dropout或较小的dropout(论文中实验尝试得到)
(4)可以不用L2或者较小的weight decay(论文中实验尝试得到)
(5)可以不用LRN(local response normalization)(也是一种归一化)
2.计算方式
可学习参数gamma和beta的作用:增加模型的可学习能力。是为了使模型自己学习该数据是否需要标准化
二、Pytorch的Batch Normalization 1d/2d/3d实现
1._BatchNorm(基类)
2.Batch Normalization 1d/2d/3d
注意事项:pytorch在实现BN的时候,当前时刻的均值和方差也考虑了之前时刻的均值和方差,具体计算方式如上图中的running_mean和running_var。
pre_running_mean为之前时刻的均值。
mean_t为当前t时刻的均值和方差
3.Batch Normalization 1d/2d/3d的输入及计算方式
说明:BN在计算时,是在每一批数据的每一个特征维度上分别计算一个均值和方差,如上图的讲解
4.pytorch代码实现
bn = orch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) #1d
bn = torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)#2d
bn = torch.nn.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) #3d
参考
深度之眼pytorch框架班以及pytorch中文文档