论文标题: Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
论文链接: https://arxiv.org/pdf/1502.03167.pdf
Introduction
这篇论文的主要工作就是提出了如今深度学习常见的 Batch Normalization,来加速深层网络训练的收敛,以及在 Inception v1 的基础上做了一些训练方式和结构上的改进(Inception v2),其在 ImageNet 分类任务上是超越了当时最好的成绩:4.9% top-5 validation error。
Motivation
这篇论文拟解决的问题是深度学习中很关键的问题,深度神经网络模型的训练为什么困难、收敛慢? 这个问题的解决在之前的工作中,有从尝试新的激活函数角度,如 ReLU、Maxout、PReLU等;有从权值初始化方法角度,如Xavier初始化、Kaiming初始化等,但收益相对都不是很高。
作者也指出了因为这个问题的存在使得在设置 learning rate、初始化方法还有激活函数上得要很慎重,而这个问题的实质是因为在训练过程中,随着前一层参数的改变,哪怕很小的改变,也会因为网络加深而被放大,而这种改变使得每一层输入的分布发生改变,从而每一层需要持续地适应这种改变,这种现象在论文中被称为 Internal covariate shift (ICS)。
从另外一个角度看,把网络的中间层看作是 sub-network 以及采用 sigmoid 激活函数 z=g(Wu+b)z = g(Wu + b)z=g(Wu+b),其中 uuu 是该层的输入,而 WWW 和 bbb 是学习的参数。随着 ∣x∣|x|∣x∣ 增大,g′(x)g'(x)g′(x) 趋向于0,这也就意味着会发生我们熟知的梯度消失(本质是落入饱和区),后来 ReLU 和 一些初始化方法是较好的缓解了这个问题。
换一个思路想,造成这个问题比较直观的原因是因为 ∣x∣|x|∣x∣ 的增大,即输入分布的变化,以及上面提到的 ICS 现象也是因为每一层的输入分布发生改变,那么很自然的想法就是,如果能确保输入的分布稳定,那么就不容易陷入饱和区域,从而梯度消失的问题也就得到很好的缓和,训练收敛的速度也随之提升了,问题就迎刃而解了。
看着好像蛮简单的,但还是有随之而来的问题:
- 确保输入的分布稳定,即 Normalization ,该怎么做?
- Normalization 能使得输入不落入饱和区域,反过来就是限制输入落入激活函数的线性区域,那这样网络不就失去了非线性的表达能力了吗,这该怎么弥补?
上面的两个也就是 Batch Normalization 这篇论文工作的核心。
Methods
白化(whitening) 作为一个很重要的数据预处理方法,它能使得模型训练收敛的更快,而白化一般包含两个目的:
- 去除特征之间的相关性(特征独立);
- 使得所有特征具有相同的均值和方差(同分布)
白化可以使得模型的输入标准化为均值为0,方差为1,那可以考虑将白化拓展到每一层的输入,就能使得每一层的分布趋于稳定。然而,标准的白化操作代价昂贵,特别是我们还希望白化操作是可微的,保证白化操作可以通过反向传播来更新梯度,即:
?Norm(x,X)?xand?Norm(x,X)?X\frac{\partial \mathrm{Norm}(x, \mathcal{X})}{\partial x} \quad and \quad \frac{\partial \mathrm{Norm}(x, \mathcal{X})}{\partial \mathcal{X}} ?x?Norm(x,X)?and?X?Norm(x,X)?
Training and Inference with Batch-Normalizaed Networks
于是就有了本篇论文的工作,Batch Normalization(BN),即对白化做了简化,将其作用到每一层的输入,使得输入分布稳定。
第一个简化是只对每一个特征维度做 Normalization,并没有考虑特征之间去相关,论文中也提到了,“such normalization speeds up convergence, even when the features are not decorrelated.”,即:
x^(k)=x(k)?E[x(k)]Var[x(k)]\widehat{x}^{(k)} = \frac{x^{(k)}-\mathrm{E}[x^{(k)}]}{\sqrt{\mathrm{Var}[x^{(k)}]}} x
(k)=Var[x(k)]?x(k)?E[x(k)]?
第二个简化是不通过整个数据集来统计 E[xk]\mathrm{E}[x^{k}]E[xk] 和 Var[x(k)]\mathrm{Var}[x^{(k)}]Var[x(k)],而是通过一个 mini-batch 内的激活值来估计均值和方差 ,得到:
μB←1m∑i=1mxiσB2←1m∑i=1m(xi?μB)2x^i←xi?μBσB2+?\mu_{\mathcal{B}} \leftarrow \frac{1}{m} \sum_{i=1}^m {x_i} \\ \sigma_{\mathcal{B}}^2 \leftarrow \frac{1}{m} \sum_{i=1}^m {(x_i - \mu_{\mathcal{B}})^2} \\ \widehat{x}_i \leftarrow \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}} μB?←m1?i=1∑m?xi?σB2?←m1?i=1∑m?(xi??μB?)2x
i?←σB2?+??xi??μB??
到这里 Normalization 就实现了,其实在后续的一些 Normalization 的工作中,都可以统一成上述第三个式子,而不同就是第一个式子和第二个式子怎么求,即如何确定神经元集合(根据集合内激活值估计均值和方差)。
而这样的 Normalization 是存在问题的,如论文中提到 “simply normalizing each input of a layer may change what the layer can represent.”,即简单地做 Normalization 会减低了网络的非线性的表达能力,比如采用 sigmoid 激活函数,normalization 会限制激活值落入到线性区域(近似线性),而这片区域是近似 [-2, 2] 这个区间,而在标准正态分布中,落入 [-2, 2] 的概率是95%。
这也就前面提到问题,作者采用的方式是 “make sure that the transformation inserted in the network can represent the identity transform”,为了实现这种可以恒等变换,引入了两个可学习的参数 γ(k)\gamma^{(k)}γ(k) 和 β(k)\beta^{(k)}β(k),使得:
y(k)=γ(k)x^(k)+β(k)y^{(k)} = \gamma^{(k)} \widehat{x}^{(k)} + \beta^{(k)} y(k)=γ(k)x
(k)+β(k)
而在最极端的情况下,学习的两个参数分别为 γ(k)=Var[x(k)]\gamma^{(k)} = \sqrt{\mathrm{Var[x^{(k)}]}}γ(k)=Var[x(k)]? 和 β(k)=E[x(k)]\beta^{(k)} = \mathrm{E}[x^{(k)}]β(k)=E[x(k)],则可以恢复到原先的激活值,即恢复了网络的表达能力(restore the representation power of the network)。
Batch Normalization 大致的算法过程如下:
BN 采用 mini-batch 来估计均值和方差,这在训练的时候是可行的,但在 inference 或 online inference 时,是单实例的,不存在 mini-batch,所以就无法获得BN计算所需的均值和方差,这就需要利用训练阶段的Batch统计值,估计一个总体的均值和方差,从而实现 inference 阶段的 normalization:
x^=x?E[x]Var[x]+?E[x]←EB[uB]Var[x]←m1+mEB[σB2]\widehat{x}=\frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \\ \mathrm{E}[x] \leftarrow \mathrm{E}_{\mathcal{B}}[u_{\mathcal{B}}] \\ \mathrm{Var}[x] \leftarrow \frac{m}{1 + m}\mathrm{E}_{\mathcal{B}}[\sigma^2_{\mathcal{B}}] x
=Var[x]+??x?E[x]?E[x]←EB?[uB?]Var[x]←1+mm?EB?[σB2?]
值得注意的是,在PyTorch代码实现的时候,会去采用指数滑动平均(Exponential Moving Average)来实现总体的估计,见PyTorch文档描述:
Mathematically, the update rule for running statistics here is x^new=(1?momentum)×x^+momentum×xt\widehat{x}_{new} = (1 - \mathrm{momentum}) \times \widehat{x} + \mathrm{momentum} \times x_tx new?=(1?momentum)×x +momentum×xt?, where x^\widehat{x}x is the estimated statistic and and xtx_txt? is new observed value.
Batch Normalization enables higher learning rates
在论文摘要部分提到了,“Batch Normalization allows us to use much higher learning rates and be less careful about initialization.”,在还没看正文时,就有疑问,为什么 BN 可以使得采用更大的学习率?作者专门在一小节中对其进行了解释。
通常在训练深层网络,不会使用太大的学习率,因为它易导致梯度爆炸、梯度消失或者陷入到 poor local minima。而前面也提到,BN 能够一定程度上避免输入落入到激活函数的饱和区域,缓解了梯度消失的问题;另外因为每层输入都有 normalization 的存在,缓解了 ICS 的存在,使得每层的输入分布稳定,参数的变化(反向传播的梯度)也趋于稳定(不会因为随着层数加深,参数变化被放大),较好地缓解了梯度爆炸。
论文中称 “Batch Normalization also makes training more resilient to parameter scale” ,并对这种参数更新(梯度)的稳定做了一个分析。因为通常,较大的学习率会使得参数的规模增大(increase the scale of layer parameters),假设增大了 aaa 倍,但因为 BN 的存在,反向传播并不会受到参数增大的影响,从而导致的梯度爆炸:
BN(Wu)=BN((aW)u)?BN((aW)u)?u=?BN(Wu)?u?BN((aW)u)?(aW)=1a??BN(Wu)?W\mathrm{BN}(Wu) = \mathrm{BN}((aW)u) \\ \frac{\partial \mathrm{BN}((aW)u)}{\partial u} = \frac{\partial \mathrm{BN}(Wu)}{\partial u} \\ \frac{\partial \mathrm{BN}((aW)u)}{\partial (aW)} = \frac{1}{a} \cdot \frac{\partial \mathrm{BN}(Wu)}{\partial W} BN(Wu)=BN((aW)u)?u?BN((aW)u)?=?u?BN(Wu)??(aW)?BN((aW)u)?=a1???W?BN(Wu)?
第三个式子还能看到,更大的参数返回会导致更小的梯度,从而上面的结论也得到了验证:BN 的存在,使得参数变化趋于稳定,故能使用更大的学习率。
另外,论文中还提到了一点,暂时还不理解,"We further conjecture that Batch Normalization may lead the layer Jacobians to have singular values close to 1, which is known to be beneficial for training "
Batch Normalization regularizes the model
在论文的摘要中提到因为 BN 的存在,使得 Dropout 可以被移除或者减小神经元被 drop 的概率,换句话说,BN 具备了 dropout 提升模型泛化能力(缓解过拟合)的功能,主要是因为通过 batch 内的激活值估计均值和方差,不是根据单一 sample 的值做模型优化,这可以变相地看成是某一种约束。
这边有三点需要注意的是:
- 论文中讨论将BN放在激活函数前好还是后好,“but since u is likely the output of another nonlinearity, the shape of its distribution is likely to change during training, and constraining its first and second moments would not eliminate the covariate shift.”,不是很明白,而 [1] 中提到 “不少研究表明将BN放在激活函数之后效果更好。” ,但并没有给出参考文献,这里还需后面验证。
- 对于卷积层,在估计均值和方差时考虑的神经元激活值集合并不是标量特征,而是一个 feature map 内的特征值,“we jointly normalize all the activations in a mini-batch over all locations. In Alg. 1, we let B be the set of all values in a feature map across both the elements of a mini-batch and spatial locations”
- 在 Normalization 前的线性变换 Wu+bWu + bWu+b,通常会省略掉偏置项 bbb ,因为它的作用会被随后BN的均值给抵消掉(它的作用会被 β\betaβ 代替)。
Experiments
作者通过一个很简单的实验来验证BN是否能较好的缓解ICS问题。在MNIST数据集上,一个简单的三层的神经网络,在每一层之前加入BN,观察收敛速度和每层的输入分布变化,结果如下图:
另外一个实验就是在 ImageNet Classification 任务上,实验的模型结构在 Inception-v1 上做的改进是参考 VGG 中小尺寸的卷积核的思想,将 5×5 的卷积核替换为 两层 3 × 3 的卷积核,并且加宽了网络和引入了BN,另外值得注意的是,在Inception-v2的网络结构设计中,用 stride=2 代替了 Inception-v1 中 max pool 做 feature map 的尺寸缩减。
为了进一步加速 BN networks 训练收敛的速度,作者进一步改进了网络和训练的超参数:
- Increase learning rate:在上面也提到了可以使用更大的学习率,那么增大学习率来加快模型训练;
- Remove dropout:前面同样提到 BN 能一定程度缓解过拟合,那就将 dropout 移除来提速;
- Reduce the L2L_2L2? weight regularization:减少 L2 正则项的权值;
- Accelerate the learning rate decay:因为 BN 加快了模型的训练,所以相应的学习率的衰减也得加快,这里采用了指数衰减;
- Remove Local Response Normalization:这里是移除了 LRN,其实在之前的一些工作中就不用它了;
- Shuffle training examples more thoroughly
- Reduce the photometric distortions
另外,作者还做了一些调整,设置了几个模型:
- BN-Baseline:即相较于 Inception-v1 引入了 BN;
- BN-x5:将学习率增大了5倍,为0.0075
- BN-x30:将学习率增大了30倍,为0.045
- BN-x5-Sigmoid:和 BN-x5 类似,但激活函数换成了Sigmoid。
下面是实验的结果,可以看到仅采用 BN 的Baselin 其收敛速度是快于 Inception v1,另外 BN-x5 更是快了14倍(达到Inception 的准确率),而虽 BN-x30 较之 BN-x5 会慢一些,但其准确率是最高的。有意思的是,论文中还玩了一个小小的文字游戏:we apply Batch Normalization to the bestperforming ImageNet classification network, and show that we can match its performance using only 7% of the training steps, and can further exceed its accuracy by a substantial margin. ,其实这里 7% 是由 BN-x5 实现的,而 BN-x30 是在准确率上做了提升。
Conclusion
Batch Normalization 加速了深层网络的训练,使得现在 Normalization 成了深度学习的标配。而在论文的 Conclusion 部分提出的一些对未来工作的看法也指出了 BN 的一些缺陷,如 BN 在 RNN 上并不那么work,主要原因是过于依赖 Batch 内进行统计,这也使得有了后来的 Layer Normalization 等一系列工作。
References
[1] 深度学习中的Normalization模型, https://zhuanlan.zhihu.com/p/43200897
[2] 深入解读Inception V2之Batch Normalization(附源码), https://zhuanlan.zhihu.com/p/50444499
[3] 详解深度学习中的Normalization,不只是BN(1), https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/79276708
[4] Batch Normalization导读, https://zhuanlan.zhihu.com/p/38176412