Layer Normalization 中的不变性Invariance 分析
- 1.介绍
- 2.计算
-
- 2.1计算矩阵
- 2.2 μ 与 σ 的计算
- 2.3 数据发生的变化
-
- 2.3.1权重矩阵**W**缩放
- 2.3.2权重矩阵**W**偏置
- 2.3.3其它
- 2.3六种变化的对比
1.介绍
引用
文章引用2016年的Layer Normalization 这篇文章,总结文章中所提及的层归一化、批量归一化、权重归一化的不变性对比。
原文链接:layer normalization
计算公式
The proposed layer normalization is related to batch normalization and weight normalization. Although, their normalization scalars are computed differently, these methods can be summarized as normalizing the summed inputs a i a_{i} ai? to a neuron through the two scalars μ and σ. They also learn an adaptive bias b and gain g for each neuron after the normalization.
h i = f ( g i σ i ( a i ? μ i ) + b i ) h_{i} = f( \dfrac{g_{i}}{σ_{i}}(a_{i} ?μ_{i}) + b_{i}) hi?=f(σi?gi??(ai??μi?)+bi?)
此处提出的层归一化与批量归一化和权重归一化有关。尽管它们的归一化标量的计算方式不同,但这些方法可以概括为通过两个标量 μ 和 σ 对神经元的总和输入 a i a_{i} ai? 进行归一化。他们还在标准化后为每个神经元学习了一个自适应偏差 b 和增益 g。
h i = f ( g i σ i ( a i ? μ i ) + b i ) h_{i} = f( \dfrac{g_{i}}{σ_{i}}(a_{i} ?μ_{i}) + b_{i}) hi?=f(σi?gi??(ai??μi?)+bi?)
下表提供了3种归一化的不变性对比
所谓不变性,即即使计算数据或参数矩阵发生变化,如:放缩和偏移,经过归一化后使输出数据不发生改变。 如此,使得网络更加稳定,加速训练。
2.计算
根据神经网络的计算矩阵,推导出归一化是如何将发生变化的数据拉回到原输出上的。
2.1计算矩阵
权重矩阵W
W 11 W_{11} W11? | W 21 W_{21} W21? |
W 12 W_{12} W12? | W 22 W_{22} W22? |
W 13 W_{13} W13? | W 23 W_{23} W23? |
输入矩阵X
X 11 X_{11} X11? | X 12 X_{12} X12? | X 13 X_{13} X13? |
X 21 X_{21} X21? | X 22 X_{22} X22? | X 23 X_{23} X23? |
输出矩阵summed input
X 11 ? W 11 + X 12 ? W 12 + X 13 ? W 13 X_{11}*W_{11}+X_{12}*W_{12}+X_{13}*W_{13} X11??W11?+X12??W12?+X13??W13? | X 11 ? W 21 + X 12 ? W 22 + X 13 ? W 23 X_{11}*W_{21}+X_{12}*W_{22}+X_{13}*W_{23} X11??W21?+X12??W22?+X13??W23? | |
X 21 ? W 11 + X 22 ? W 12 + X 23 ? W 13 X_{21}*W_{11}+X_{22}*W_{12}+X_{23}*W_{13} X21??W11?+X22??W12?+X23??W13? | X 21 ? W 21 + X 22 ? W 22 + X 23 ? W 23 X_{21}*W_{21}+X_{22}*W_{22}+X_{23}*W_{23} X21??W21?+X22??W22?+X23??W23? |
即X*W = input
2.2 μ 与 σ 的计算
batch norm:每个批次第 L 层的第 i 个单元的均值作为 μ ,即计算batch维度的统计量
μ i l = E [ a i l ] , σ i l = E ( a i l ? μ i l ) 2 μ^{l}_{i} = \Epsilon [a^{l}_{i}] , σ^{l}_{i} = \sqrt{\Epsilon (a^{l}_{i} - μ^{l}_{i})^{2}} μil?=E[ail?],σil?=E(ail??μil?)2?
layer norm:每个样本第 L 层的H个单元的均值作为 μ ,即计算layer维度的统计量
μ l = 1 H Σ i = 1 H a i l , σ l = 1 H Σ i = 1 H ( a i l ? μ l ) 2 μ^{l} = \dfrac{1}{H} \Sigma_{i=1}^{H}a^{l}_{i} , σ^{l} = \sqrt{ \dfrac{1}{H}\Sigma_{i=1}^{H}(a^{l}_{i} - μ^{l})^{2}} μl=H1?Σi=1H?ail?,σl=H1?Σi=1H?(ail??μl)2?
wight norm:在权重归一化中, μ 是 0, σ = ‖w‖2
因此对应到计算矩阵中
batch norm:计算矩阵的变化在列维度上的均值μ 和 σ 变化,以第一列为例
1 / 2 ( x 11 ? w 11 + x 12 ? w 12 + x 13 ? w 13 ) + 1 / 2 ( x 21 ? w 11 + x 22 ? w 12 + x 23 ? w 13 ) 1/2(x_{11}*w_{11}+x_{12}*w_{12}+x_{13}*w_{13})+1/2(x_{21}*w_{11}+x_{22}*w_{12}+x_{23}*w_{13}) 1/2(x11??w11?+x12??w12?+x13??w13?)+1/2(x21??w11?+x22??w12?+x23??w13?)
layer norm:计算矩阵的变化在行维度上的均值μ 和 σ 变化,以第一行为例
1 / 2 ( x 11 ? w 11 + x 12 ? w 12 + x 13 ? w 13 ) + 1 / 2 ( x 11 ? w 21 + x 12 ? w 22 + x 13 ? w 23 ) 1/2(x_{11}*w_{11}+x_{12}*w_{12}+x_{13}*w_{13})+1/2(x_{11}*w_{21}+x_{12}*w_{22}+x_{13}*w_{23}) 1/2(x11??w11?+x12??w12?+x13??w13?)+1/2(x11??w21?+x12??w22?+x13??w23?)
2.3 数据发生的变化
对数据矩阵的变化分为三种,在论文中叫做 matrix re-scaling 、matrix re-centering 、 vector re-scaling,即矩阵的缩放、偏置和向量缩放。按数据变化对象又分为两类,分别是输入矩阵X、权重矩阵W,组合起来有六种变化。
2.3.1权重矩阵W缩放
对W 分别乘2,会引起输出矩阵summed input怎样的变化。
befor: X 11 ? W 11 + X 12 ? W 12 + X 13 ? W 13 X_{11}*W_{11}+X_{12}*W_{12}+X_{13}*W_{13} X11??W11?+X12??W12?+X13??W13?
after: 2 ? ( X 11 ? W 11 + X 12 ? W 12 + X 13 ? W 13 ) 2*(X_{11}*W_{11}+X_{12}*W_{12}+X_{13}*W_{13}) 2?(X11??W11?+X12??W12?+X13??W13?)
将原来的输出记为 Y 11 Y 12 Y 21 Y 22 Y_{11} Y_{12} Y_{21} Y_{22} Y11?Y12?Y21?Y22?,则新的矩阵input为
2 Y 11 2Y_{11} 2Y11? | 2 Y 12 2Y_{12} 2Y12? |
2 Y 21 2Y_{21} 2Y21? | 2 Y 22 2Y_{22} 2Y22? |
根据2.2中的均值方差计算公式,μ 和 σ 产生的变化
batch norm下的均值μ变为2倍,σ 也变为2倍
layer norm下的均值μ变为2倍,σ 也变为2倍
weight norm下的范数σ = ‖w‖2变为原来的2倍
那么经过norm后得到的输出会产生怎样的变化:
根据公式: h i = f ( g i σ i ( a i ? μ i ) + b i ) h_{i} = f( \dfrac{g_{i}}{σ_{i}}(a_{i} ?μ_{i}) + b_{i}) hi?=f(σi?gi??(ai??μi?)+bi?),其中&g_{i}$ 和 b i b_{i} bi? 属于可学习参数,这里不管。
befor | after | |
---|---|---|
batch norm | (input-μ)/σ | (2input-2μ)/2σ |
layer norm | (input-μ)/σ | (2input-2μ)/2σ |
weight norm | input/norm*w | input/(2norm) * (2w) |
由于分子分母的2可以消掉,所以得到的权重矩阵的放缩并不影响输出的结果,因此称其具有不变性。
2.3.2权重矩阵W偏置
对W 分别加2,会引起输出矩阵summed input怎样的变化。
befor: X 11 ? W 11 + X 12 ? W 12 + X 13 ? W 13 X_{11}*W_{11}+X_{12}*W_{12}+X_{13}*W_{13} X11??W11?+X12??W12?+X13??W13?
after: ( X 11 ? W 11 + X 12 ? W 12 + X 13 ? W 13 ) + 2 ( X 11 + X 12 + X 13 ) (X_{11}*W_{11}+X_{12}*W_{12}+X_{13}*W_{13}) + 2(X_{11}+X_{12}+X_{13}) (X11??W11?+X12??W12?+X13??W13?)+2(X11?+X12?+X13?)
将原来的输出记为 Y 11 Y 12 Y 21 Y 22 Y_{11} Y_{12} Y_{21} Y_{22} Y11?Y12?Y21?Y22?,则新的矩阵input为
2 Y 11 + 2 ( X 11 + X 12 + X 13 ) 2Y_{11}+2(X_{11}+X_{12}+X_{13}) 2Y11?+2(X11?+X12?+X13?) | 2 Y 12 + 2 ( X 11 + X 12 + X 13 ) 2Y_{12}+2(X_{11}+X_{12}+X_{13}) 2Y12?+2(X11?+X12?+X13?) |
2 Y 21 + 2 ( X 21 + X 22 + X 23 ) 2Y_{21}+2(X_{21}+X_{22}+X_{23}) 2Y21?+2(X21?+X22?+X23?) | 2 Y 22 + 2 ( X 21 + X 22 + X 23 ) 2Y_{22}+2(X_{21}+X_{22}+X_{23}) 2Y22?+2(X21?+X22?+X23?) |
由于每一行加的一样所以对于按行计算的layer norm友好,对按列计算的batch norm不友好。
batch norm下的均值μ变为一个新值μ’,σ 也变为一个新值σ‘
layer norm下的均值μ变为μ + 2 X 1 X_{1} X1?,σ 不变
weight norm下的范数σ = ‖w‖2变为一个新值‖w‖2’
那么经过norm后得到的输出会产生怎样的变化:
befor | after | |
---|---|---|
batch norm | (input-μ)/σ | (input’ - μ’)/σ’ |
layer norm | (input-μ)/σ | (input + 2 -(μ +2))/σ |
weight norm | input/norm*w | input/(norm’) * (w+2) |
由于layer norm的 +2 -2 相抵消,故只有layer norm 在 权重矩阵W偏置保持不变性。
2.3.3其它
上述两个例子很好的说明了,不变性的推导方法,故其余的推导不多做赘述,其中涉及的是概率论的基本知识。
对于 vector re-scaling 向量缩放变化,其推导在于只对某一列(向量)做缩放的操作。
2.3六种变化的对比
condition | batch norm | layer norm | weight norm | |||
---|---|---|---|---|---|---|
befor | after | befor | after | befor | after | |
对W矩阵乘2 | (input-μ)/σ | (2input-2μ)/2σ | (input-μ)/σ | (2input-2μ)/2σ | input/norm*w | input/(2norm) * (2w) |
对W矩阵加2 | (input-μ)/σ | (input’ - μ’)/σ’ | (input-μ)/σ | (input + 2t -(μ +2t))/σ | input/norm*w | input/(norm’) * (w+2) |
对W的第一列乘2 | (input-μ)/σ | (2input-2μ)/2σ | (input-μ)/σ | (input-μ’)/σ‘ | input/norm*w | input/(2norm) * (2w) |
对X矩阵称2 | (input-μ)/σ | (2input-2μ)/2σ | (input-μ)/σ | (2input-2μ)/2σ | input/norm*w | 2*input/norm *w |
对X矩阵加2 | (input-μ)/σ | (input + 2t -(μ +2t))/σ | (input-μ)/σ | (input’ - μ’)/σ’ | input/norm*w | (input+2)/norm*w |
对X的第一行乘2 | (input-μ)/σ | (input’ - μ’)/σ’ | (input-μ)/σ | (2input-2μ)/2σ | input/norm*w | input/norm*w |
其中2t 代表一个增量,input’ 、 μ’ 、σ’ 代表这是一个新值和原值无关。