欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch之——Batch Normalization

程序员文章站 2022-05-18 10:04:26
...

pytorch之——Batch Normalization


一、Batch Normalization概念

1.Batch Normalization:批标准化

批:一批数据,通常为mini_batch
标准化:0均值,1方差
优点:
(1)可以用更大的学习率,加速模型收敛
(2)可以不用精心设计权值初始化(因为权值初始化也是为了缩放数据尺度,BN有着同样的作用)
(3)可以不用dropout或较小的dropout(论文中实验尝试得到)
(4)可以不用L2或者较小的weight decay(论文中实验尝试得到)
(5)可以不用LRN(local response normalization)(也是一种归一化)

2.计算方式

pytorch之——Batch Normalization
可学习参数gamma和beta的作用:增加模型的可学习能力。是为了使模型自己学习该数据是否需要标准化

二、Pytorch的Batch Normalization 1d/2d/3d实现

1._BatchNorm(基类)

pytorch之——Batch Normalization

2.Batch Normalization 1d/2d/3d

pytorch之——Batch Normalization
注意事项:pytorch在实现BN的时候,当前时刻的均值和方差也考虑了之前时刻的均值和方差,具体计算方式如上图中的running_mean和running_var。

pre_running_mean为之前时刻的均值。
mean_t为当前t时刻的均值和方差

3.Batch Normalization 1d/2d/3d的输入及计算方式

pytorch之——Batch Normalization
说明:BN在计算时,是在每一批数据的每一个特征维度上分别计算一个均值和方差,如上图的讲解

4.pytorch代码实现

bn = orch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) #1d

pytorch之——Batch Normalization

bn = torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)#2d

pytorch之——Batch Normalization

bn = torch.nn.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) #3d

pytorch之——Batch Normalization

参考

深度之眼pytorch框架班以及pytorch中文文档

相关标签: pytorch框架知识