欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

『Pytorch』BatchNorm

程序员文章站 2022-07-13 11:38:42
...

1. Batch Normalization

  • 批:一批数据,通常为mini batch
  • 标准化: 0均值,1方差
  • 优点:
    1. 可以用更大的学习率,加速模型收敛
    2. 可以不用精心设计权值初始化
    3. 可以不用Dropout或较小的Dropout
    4. 可以不用L2或较小的权重衰减
    5. 可以不用local response normalization

2. BatchNorm

  • nn.BatchNorm1d
  • nn.BatchNorm2d
  • nn.BatchNorm3d
# 三个方法都继承自基类_BatchNorm,下面是它的初始化方法
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
    pass

主要参数:

  • num_features: 一个样本特征数量(最重要)
  • eps: 分母修正项
  • momentum: 指数加权平均估计当前mean/var
  • affine: 是否需要 affine transform
  • track_running_stats: 是训练状态还是测试状态

『Pytorch』BatchNorm

三个具体方法的主要属性:

  • running_mean: 均值
  • running_var: 方差
  • weight: affine transform中的 γ\gamma
  • bias: affine transform中的 β\beta

weight和bias都是指数加权平均得到的,后一个mini-batch要考虑前一个mini-batch

input = batch_size * 特征数 * 特征的维度

  • 特征的维度就是shape,是一个元祖,可以是多维

  • 特征数例如BP网络中的一层神经元的个数,或者CNN中的每一层输出的feature maps 的个数(类似最一开始输入的图像的channel数)