『Pytorch』BatchNorm
程序员文章站
2022-07-13 11:38:42
...
1. Batch Normalization
- 批:一批数据,通常为mini batch
- 标准化: 0均值,1方差
- 优点:
- 可以用更大的学习率,加速模型收敛
- 可以不用精心设计权值初始化
- 可以不用Dropout或较小的Dropout
- 可以不用L2或较小的权重衰减
- 可以不用local response normalization
2. BatchNorm
- nn.BatchNorm1d
- nn.BatchNorm2d
- nn.BatchNorm3d
# 三个方法都继承自基类_BatchNorm,下面是它的初始化方法
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
pass
主要参数:
- num_features: 一个样本特征数量(最重要)
- eps: 分母修正项
- momentum: 指数加权平均估计当前mean/var
- affine: 是否需要 affine transform
- track_running_stats: 是训练状态还是测试状态
三个具体方法的主要属性:
- running_mean: 均值
- running_var: 方差
- weight: affine transform中的
- bias: affine transform中的
weight和bias都是指数加权平均得到的,后一个mini-batch要考虑前一个mini-batch
input = batch_size * 特征数 * 特征的维度
特征的维度就是shape,是一个元祖,可以是多维
特征数例如BP网络中的一层神经元的个数,或者CNN中的每一层输出的feature maps 的个数(类似最一开始输入的图像的channel数)