欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Batch Normalization--全连接神经网络和卷积神经网络实战

程序员文章站 2022-07-06 11:04:53
...

Batch Normalization原理

网上博客一大堆,说的也很明白,这里就简单的说一下我的个人理解:
  1. 对每一个特征值进行 0均值化,利于神经网络拟合时,对于自身的参数b,无需修改很多次,就可以达到收敛。(因为b的初始值是设为0的)
  2. 去相关化,由于图像信息相邻像素间的信息有很多是相关的,去相关虽然有一个的训练精度损失,但是更有易于分类。

Batch Normalization好处

  1. 训练收敛速度快!
  2. 训练对于drop_out, 正则化参数, 衰减系数等容错能力更强!
  3. 还有很多优点,但博主现在还没有深入研究

全连接BN和非BN的对比

Batch Normalization--全连接神经网络和卷积神经网络实战

上图是没有采用BN时候,利用SGD+动量法来优化的损失函数和分类准确率的变化趋势。
再来看一下采用了BN算法的结果

Batch Normalization--全连接神经网络和卷积神经网络实战

这效果也太好了!!!!!

全连接神经网络Batch Normalization实现

BN正向传播伪代码

输入: 一批量的样本x、 可学习参数gamma、 可学习参数beta

输出 : BN化的样本

算法:

  1. 计算批量样本的均值
  2. 计算批量样本方差
  3. 计算 x’ = (x - 均值) / sqrt(方差)
  4. 输出新样本值为 y = gamma * x’ + beta

代码实现如下

    sample_mean = np.mean(x, axis = 0)
    sample_var = np.var(x, axis = 0)
    x_hat = (x - sample_mean) / (np.sqrt(sample_var + eps))
    out = gamma * x_hat + beta
    cache = (gamma, x, sample_mean, sample_var, eps, x_hat)
    running_mean = momentum * running_mean + (1 - momentum) * sample_mean
    running_var = momentum * running_var + (1 - momentum) * sample_var

BN后向传播实现

这里就得先看看其求导的公式了,图片如下:

Batch Normalization--全连接神经网络和卷积神经网络实战

这是求导的公式图片

Batch Normalization--全连接神经网络和卷积神经网络实战

代码实现如下:

  #cache中存着前向传播时候的参数
  gamma, x, sample_mean, sample_var, eps, x_hat = cache
  N = x.shape[0]
  dx_hat = dout * gamma
  #计算方差对结果的偏导
  dvar = np.sum(dx_hat* (x - sample_mean) * -0.5 * np.power(sample_var + eps, -1.5), axis = 0)
  #计算均值对结果的偏导
  dmean = np.sum(dx_hat * -1 / np.sqrt(sample_var +eps), axis = 0) + dvar * np.mean(-2 * (x - sample_mean), axis =0)
  #计算该批量样本对结果的偏导
  dx = 1 / np.sqrt(sample_var + eps) * dx_hat + dvar * 2.0 / N * (x-sample_mean) + 1.0 / N * dmean
  #计算可学习参数gamma对结果的偏导
  dgamma = np.sum(x_hat * dout, axis = 0)
  #计算可学习参数beta对结果的偏导
  dbeta = np.sum(dout , axis = 0)
return dx, dgamma, dbeta
相关标签: 神经网络 batch