欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Batch Normalization详细解读

程序员文章站 2023-12-21 14:19:10
...

这篇文章是论文 Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift 的翻译,其中精简了部分内容的同时解释了相关的概念,如有错误,敬请指教。

Abstract

在神经网络训练过程中,前一层权重参数的改变会造成之后每层输入样本分布的改变,这造成了训练过程的困难。为了解决这个问题,通常会使用小的学习率和参数初始化技巧,这样导致了训练速度变慢,尤其是训练具有饱和非线性的模型。我们将这一现象定义为internal covariate shift,并提出通过规范化输入来解决。将标准化作为模型架构的一部分,并且对每一个training mini-batch使用标准化来增强算法的效果。BN允许我们使用更高的学习率,并且不用太过关心初始化。同时由于有正则化的效果,有时还能省略Dropout,加速模型训练。BN在ImageNet上的测试取得了很好的效果。

1 Introduction

SGD(Stochastic gradient descent)是一种训练神经网络的有效方法。SGD的改进算法,如momentum、Adagrad都应用广泛。SGD通过优化权重参数来减小loss:
Batch Normalization详细解读
训练过程中的mini-batch的梯度计算如下:
Batch Normalization详细解读

使用mini-batch的优点如下:

  • mini-batch的梯度是对整个数据集梯度的估计,当数据集过大时,计算量暴增,使用mini-batch能提高梯度更新的效率(可参考GoodFellow的DeepLearning教材)
  • 对于含有m个单独样本的mini-batch的一次计算比单独计算m次更快,这得益于并行计算。

虽然SGD简单有效,但是超参数调整十分麻烦,尤其是学习率和初始化。由于每层网络都受到之前所有网络的影响,随着网络加深,每层网络参数的微小变化都会被之后的网络逐渐放大。
每层网络参数的变化会导致一些问题,因为之后的网络必须适应新的分布。假设有一个网络的loss为:
Batch Normalization详细解读
其中F1,F2为任意变换,内层的函数
Batch Normalization详细解读
可以看成是外层函数的一个参数,即:
Batch Normalization详细解读
对于一步梯度更新:
Batch Normalization详细解读
可以看做是独立的网络F2,输入为x。输入分布的特殊性质可以使得训练过程更高效,比如训练集和测试集完全相同的时候。因此,如果固定x的分布,参数就不用再补偿因为x的分布改变引起的变化。
在子网络中固定输入的分布对于外层网络有积极的影响。考虑一个具有sigmoid**函数的网络:
Batch Normalization详细解读
其中,W和b是需要学习的参数,g(x)为sigmoid函数。由sigmoid函数的性质可知,当x的绝对值增大时,g的导数会逐渐减小至零。这意味着对于所有的x=Wu+b,除开绝对值较小的部分,其余的x值都会使梯度逐渐消失(gradient vanish)。由于x受W和b的影响,改变这两个参数有可能使得x的很多维度陷入饱和的非线性区,降低收敛速度。这个效应会随着网络层数的加深变得愈发明显。在实践过程中,这个问题通常可以使用ReLU来解决,但是如果可以控制非线性输入的分布更加稳定,就可以使优化器更少陷入饱和区域,以达到加速的目的。

2 Towards Reducing Internal Covariate Shift

我们将 Internal Covariate Shift1 定义为训练过程中网络参数的改变引起的网络**函数分布的改变。BN的目标是减少 Internal Covariate Shift ,以达到加速的目的。已知的是,如果网络的输入进行了白化23,训练时会收敛得更快。这可以通过线性变换,使得输入为零均值,单位方差,并且去除了相关性的数据分布。白化操作可以通过直接更改网络或者改变优化算法来实现。但是如果直接对原始数据进行白化,需要计算输入x的协方差矩阵和求逆,计算量会很大。这促使我们寻找一种替代的方法来完成输入标准化,并且在参数更新过程中不需要对整个数据集进行分析。
注:
1. 白化的解释可参考 UFLDL的教程4
2. Covariate Shift的解释可以参考这里

3 Normalization via Mini-Batch Statistics

因为对每一层的输入进行白化代价太大,这里作出两个必要的简化。

独立地标准化每一个标量特征(均值为零,方差为1),而不是将输入层和输出层的特征一起进行白化。对于一个d维的输入,标准化每一个单独的维度的公式为:
Batch Normalization详细解读
其中,期望和方差都是在 整个数据集 上进行计算。LeCun5 的文献表明,这种标准化的过程即使在数据没有去除相关性的条件下,仍然可以加速收敛。
注意,如果简单地对一层的输入进行标准化可能会影响这一层的表达能力。比如在使用sigmoid**函数的时候,如果简单地将输入数据进行零均值单位方差标准化,将使得原始数据更加集中,对应于sigmoid函数的中间部分,也就是只使用了函数的线性区域。详细的解释可以参考 这里 为避免这一问题,可以使用重构变换
Batch Normalization详细解读
其中的参数γ(k)β(k) 代表了放缩和平移的变换参数,需要和原始模型的参数一起学习。这样能恢复模型的表达能力。
注:对于重构变换的理解是,在标准化的过程中减去了均值,这相当于对原始数据的平移;除以标准差,这相当于对原始数据的放缩。因此重构变换可以考虑为反向的过程,即先进行放缩,也就是乘上γ(k);再进行平移,也就是加上β(k),只不过这两个参数需要进行学习。最特殊的情况是,如果设置 γ(k)=Var[x(k)]β(k)=E[x(k)] ,就可以恢复原始的数据。
在每个训练步骤的参数更新是基于整个训练集的情况下,我们可以使用整个训练集的来标准化**函数。但是这对于SGD不可行,因此作出第二个简化:
因为在SGD中使用mini-batches,每一个mini-batches在每一次**过程中产生均值和方差的估计。这样,标准化过程使用的统计数据就能在梯度反向传播过程中用上。注意使用mini-batches是通过方差的每一个单独的维度进行的,而不是协方差。如果是协方差的情况,将需要进行规范化,因为mini-batch的大小可能比白化的**函数要少,这将导致协方差矩阵奇异。
考虑一个mini-batch B,大小为m。BN算法流程如下:
Batch Normalization详细解读
其中,ε 是为保证数值稳定性添加到mini-bacth中的常数。
BN算法可以添加到网络中来控制**。注意到BN()并不是单独地在每一个训练样本中处理**,而是既依赖训练样本,也依赖于mini-batch中的其他样本。经过伸缩和平移变换的y 传播到其他层。如果忽略ε ,只要每一个mini-batch的元素是从同一个分布中采样,x^ 的取值分布都满足零均值和单位方差。子网络的输入具有固定的均值和方差。虽然所有x^(k) 的联合分布会在训练过程中发生改变。由于标准化子网络的引入会加速子网络的训练过程,最终加速整个网络的训练过程。
在训练过程中需要反向传播 的梯度,同时需要计算BN中参数的梯度,使用链式法则,简化前的公式为:
Batch Normalization详细解读
因此,BN是一个可微的变换,并且引入了正则化的效果。这保证了在模型训练时,每一层展示出更少的内部internal covariate shift的输入中进行学习,因此加速了训练过程。更重要的是,学习到的仿射变换参数可以保证网络的容量不受影响。

3.1 Training and Inference with Batch-Normalized Networks

对于BN的实际应用,原先接受 x 作为输入的单层网络现在接受BN(x) 作为输入。使用BN的网络可以用SGD及相关的改进算法(如Adagrad)来进行优化。但是在推理过程中,不需要网络进行标准化,而是希望网络的输出只依赖于原始的输入。因此,在网络训练完成后,使用如下公式:
Batch Normalization详细解读
即使用整个数据集所有样本,而不是一个mini-batch的统计资料。如果忽略ε ,标准化的**就具有和训练时相同的零均值和单位方差。使用无偏方差估计:
Batch Normalization详细解读
其中,期望是在mini-batch上进行计算,标准差是在所有样本上进行计算。因为均值和方差在推理过程中都是固定的,标准化只是一个用于**中的线性变换。下图是使用BN的网络的训练步骤:
Batch Normalization详细解读

3.2 Batch-Normalized Convolutional Networks

BN可以被应用到任何网络的**中。考虑一个如下的变换:
Batch Normalization详细解读
其中,Wb 是学习到的参数,g() 是非线性函数比如sigmoid或者ReLU。我们在非线性运算之前添加BN,即标准化x=W+b。当然,也可以直接标准化输入层中的u,但是由于u 其他非线性函数的输出,在训练过程中它的分布可能会发生改变。对它的一阶和二阶矩进行约束无法消除covariate shift。反之,x=W+b 则更可能具有一个对称、非稀疏的分布,更加近似正态分布。标准化这个分布更有希望产生一个具有稳定分布的**。
注意到标准化x=W+b 的过程中,偏差b ,可以忽略(去均值的过程)。因此,z=g(BN(Wu+b)) 可以替换为:
z=g(BN(Wu))
BN变换是应用到x=Wu 的每一个维度。在每一个维度都具有一对独立的参数γ(k)β(k)
对于卷级层,我们希望标准化过程符合卷积的一些性质以便于同一个特征图中的不同位置的元素,都可以使用相同的标准化。为了实现这一点,我们对一个mini-batch中的所有位置的**同时进行标准化。对于一个大小为m 的mini-batch,特征图的尺寸为pq,我们使用的有效mini-batch大小为m=mpq,在每一个特征图中学习到一对γ(k)β(k) 参数,而不是在每一处**中(有点类似于权值共享)。

3.3 Batch Normalization enables higher learning rates

通常,大的学习率会造成单层的参数放大,在反向传播的过程中,又会造成梯度的放大。但是如果使用BN的话,反向传播过程可以不受影响。假设放大倍数为a
Batch Normalization详细解读
不难推出:
Batch Normalization详细解读
由于放大过程不会影响单层的Jacobian矩阵,最终也不会影响反向传播。更重要的是,更大的权重会导致更小的梯度,因此BN具有稳定参数增长的作用。
我们可以猜想BN可能导致Jacobian矩阵具有接近于1的奇异值,这对于训练过程是有利的6。考虑两个相邻的层,同时对这两个层进行标准化输入。假设两个层之间的变换为:
z^=F(x^)
如果假定z^x^ 都属于正太分布并且不相关,那么
F(x^)Jx^ 对于给定的模型参数就近似于一个线性变换。并且如果z^x^ 都具有单位协方差,即I=cov[z^]=Jcov[x^]JT=JJT,因此J 的所有奇异值为1,这在反向啊传播的过程中可以稳定梯度的量级。实际的变换是非线性的,标准化之后的输入也不可能完全保证正态分布或者完全独立,但是BN仍然可以使反向传播的过程表现得更好。

4 Experiments

4.1 Activations over time

为了验证BN对于covariate shift的抑制作用,在MNIST数据集上进行验证。使用3个全连接层,每层100个神经元。每一个隐含层计算y=g(Wu+b),**函数为sigmoid,使用正态分布初始化W。loss使用交叉熵代价函数。训练50000步,每一个mini-batch有60个样本。实验结果见Figure 1。
Batch Normalization详细解读
从图(a)中可以看出,BN在少量数据的条件下能显著提高正确率。(b)(c)给出了15%、50%和85%的输入数据在整个训练过程中分布的差异,可以看到BN使得分布更加平滑,减少了internal covariate shift。

4.2 ImageNet classification

在Inception network中使用momentum优化器,mini-batch样本数为32,并使用3.2节中针对卷积的BN方法。

4.2.1 Accelerating BN Networks

简单地在网络中使用BN,并不能完全发挥出BN的优势,因此对网络的参数作出如下改变:

  • 增大学习率
  • 去除Dropout层
  • 减少L2正则
  • 增大学习率衰减
  • 去除LRN
  • 更加彻底地打乱训练样本
  • 减少图像亮度失真

4.2.2 Single-Network Classification

在LSVRC2012数据集上训练如下网络:

  • Inception:学习率0.0015
  • BN-Baseline:Inception+BN
  • BN-x5:Inception+BN+4.2.1中的改进,学习率0.0075
  • BN-x30:与BN-x5类似,初始学习率为0.045
  • BN-x5-Sigmoid:与BN-x5类似,但是使用sigmoid函数而不是ReLU
    Figure 2给出了验证集上整个训练过程中的正确率,可以看出,达到Inception相同正确率的水平,BN-x5训练步数最少。
    Batch Normalization详细解读

Figure 3给出了最大正确率和训练步数的表格:
Batch Normalization详细解读

5 Conclusion

给出keras中卷积层的BN实现的源代码:

input_shape = self.input_shape  
reduction_axes = list(range(len(input_shape)))  
del reduction_axes[self.axis]  
broadcast_shape = [1] * len(input_shape)  #len()用来统计行数
broadcast_shape[self.axis] = input_shape[self.axis]  
if train:  
    m = K.mean(X, axis=reduction_axes)  #求各维度均值
    brodcast_m = K.reshape(m, broadcast_shape)  #展开均值
    std = K.mean(K.square(X - brodcast_m) + self.epsilon, axis=reduction_axes)   #求方差
    std = K.sqrt(std)  #求标准差
    brodcast_std = K.reshape(std, broadcast_shape)  #展开标准差
    mean_update = self.momentum * self.running_mean + (1-self.momentum) * m  #更新均值
    std_update = self.momentum * self.running_std + (1-self.momentum) * std  #更新方差
    self.updates = [(self.running_mean, mean_update),  
                     (self.running_std, std_update)]  
    X_normed = (X - brodcast_m) / (brodcast_std + self.epsilon)  #标准化
else:  
    brodcast_m = K.reshape(self.running_mean, broadcast_shape)  
    brodcast_std = K.reshape(self.running_std, broadcast_shape)  
    X_normed = ((X - brodcast_m) /  
                 (brodcast_std + self.epsilon))  
out = K.reshape(self.gamma, broadcast_shape) * X_normed + K.reshape(self.beta, broadcast_shape) 

References


  1. Improving predictive inference under covariate shift by weighting the log-likelihood function
  2. LeCun, Y., Bottou, L., Orr, G., and Muller, K. Efficient
    backprop. In Orr, G. and K., Muller (eds.), Neural Networks:
    Tricks of the trade. Springer, 1998b.
  3. Wiesler, Simon and Ney, Hermann. A convergence analysis
    of log-linear training. In Shawe-Taylor, J., Zemel,
    R.S., Bartlett, P., Pereira, F.C.N., andWeinberger, K.Q.
    (eds.), Advances in Neural Information Processing Systems
    24, pp. 657–665,Granada, Spain, December 2011.
  4. http://ufldl.stanford.edu/wiki/index.php/%E7%99%BD%E5%8C%96
  5. LeCun, Y., Bottou, L., Orr, G., and Muller, K. Efficient
    backprop. In Orr, G. and K., Muller (eds.), Neural Networks:
    Tricks of the trade. Springer, 1998b.
  6. Saxe, Andrew M., McClelland, James L., and Ganguli,
    Surya. Exact solutions to the nonlinear dynamics
    of learning in deep linear neural networks. CoRR,
    abs/1312.6120, 2013.
相关标签: BN 神经网络

上一篇:

下一篇: