当前位置: 代码迷 >> 综合 >> 为什么 dot-product attention 需要被 scaled?
  详细解决方案

为什么 dot-product attention 需要被 scaled?

热度:23   发布时间:2023-11-14 12:30:48.0

前言

注意力机制也有很多种类,不同的注意力机制对应着不同的对齐分数(alignment score)计算方式。有关注意力机制的总结,大家可以看看这篇博客:Attention? Attention!

在 Attention Is All You Need 这篇论文中,有提到两种较为常见的注意力机制:additive attention 和 dot-product attention。并讨论到,当 query 和 key 向量维度 dkd_kdk? 较小时,这两种注意力机制效果相当,但当 dkd_kdk? 较大时,additive attention 要优于 dot-product attention. 但是 dot-product attention 在计算方面更具有优势。为了利用 dot-product attention 的优势且消除当 dkd_kdk? 较大时 dot-product attention 的不足,原文采用 scaled dot-product attention。

正文

那造成这种情况(但当 dkd_kdk? 较大时,additive attention 要优于 dot-product attention)的原因是什么?下面是原论文中的解释(当 dkd_kdk? 较大时,向量内积的值也会容易变得很大,这时 softmax 函数的梯度会非常的小)。

We suspect that for large values of dkd_kdk?, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely samll gradients.

我们知道,计算完各个 key 的对齐分数后需要将所有 key 的对齐分数输入到 softmaxsoftmaxsoftmax 激活函数中,得到规范化的注意力权重。

dot-product attention 中的对齐分数的计算公式为:
score(q,k)=qTkscore(q, k) = q^T k score(q,k)=qTk

先解释:为什么当 dkd_kdk? 较大时,向量内积容易取很大的值(借用原论文的注释)

假设 query 和 key 向量中的元素都是相互独立的均值为 0,方差为 1 的随机变量,那么这两个向量的内积 qTk=∑i=1dkqikiq^T k = \sum_{i=1}^{d_k} q_ik_iqTk=i=1dk??qi?ki? 的均值为 0,而方差为 dkd_kdk?.

证明:
已知 E[qi]=E[ki]=0,Var(qi)=Var(ki)=1\text{E}[q_i] = \text{E}[k_i] = 0,\ \text{Var}(q_i)=\text{Var}(k_i)=1E[qi?]=E[ki?]=0, Var(qi?)=Var(ki?)=1.

由于 qiq_iqi?kik_iki? 相互独立,则两者的协方差为 0:
Cov(qi,ki)=E[(qi?E[qi])(ki?E[ki])]=E[qiki]?E[qi]E[ki]=0\begin{aligned} \text{Cov}(q_i,k_i) &= \text{E}\left[\left(q_i-\text{E}[q_i]\right)\left(k_i-\text{E}[k_i]\right)\right] \\ &= \text{E}[q_ik_i] - \text{E}[q_i] \text{E}[k_i] \\ &= 0 \end{aligned} Cov(qi?,ki?)?=E[(qi??E[qi?])(ki??E[ki?])]=E[qi?ki?]?E[qi?]E[ki?]=0?
故得 E[qiki]=E[qi]E[ki]=0\text{E}[q_ik_i] = \text{E}[q_i] \text{E}[k_i] = 0E[qi?ki?]=E[qi?]E[ki?]=0.

对于方差,有:
Var(qi)=E[qi2]?(E[qi])2=E[qi2]=1Var(ki)=E[ki2]=1\begin{aligned} \text{Var}(q_i) &= \text{E}[q_i^2] - (\text{E}[q_i])^2\\ &= \text{E}[q_i^2] \\ &= 1 \\ \text{Var}(k_i) &= \text{E}[k_i^2] = 1 \end{aligned} Var(qi?)Var(ki?)?=E[qi2?]?(E[qi?])2=E[qi2?]=1=E[ki2?]=1?
故得:
Var(qiki)=E[(qiki)2]?(E[qiki])2=E[qi2]E[ki2]?(E[qi]E[ki])2=Var(qi)Var(ki)=1\begin{aligned} \text{Var}(q_ik_i) &= \text{E}[(q_ik_i)^2] - (\text{E}[q_ik_i])^2 \\ &= \text{E}[q_i^2]\text{E}[k_i^2] - (\text{E}[q_i] \text{E}[k_i])^2 \\ & = \text{Var}(q_i)\text{Var}(k_i) \\ & = 1 \end{aligned} Var(qi?ki?)?=E[(qi?ki?)2]?(E[qi?ki?])2=E[qi2?]E[ki2?]?(E[qi?]E[ki?])2=Var(qi?)Var(ki?)=1?
由于对于两个相互独立的随机变量有如下定义:
E[X+Y]=E[X]+E[Y]Var(X+Y)=Var(X)+Var(Y)+2Cov(X,Y)=Var(X)+Var(Y)\begin{aligned} &\text{E}[X+Y] = \text{E}[X] +\text{E}[Y]\\ &\text{Var(X+Y)} = \text{Var(X)} + \text{Var(Y)} + 2\text{Cov}(X,Y) \\ &\qquad \qquad \ \ \ =\text{Var(X)} + \text{Var(Y)} \end{aligned} ?E[X+Y]=E[X]+E[]Var(X+Y)=Var(X)+Var(Y)+2Cov(X,Y)   =Var(X)+Var(Y)?
综上,可得:
E[qTk]=∑i=1dkE[qiki]=0Var(qTk)=∑i=1dkVar(qiki)=dk\begin{aligned} &\text{E}[q^T k ] = \sum_{i=1}^{d_k} \text{E}[q_ik_i] = 0\\ &\text{Var}(q^T k) = \sum_{i=1}^{d_k} \text{Var}(q_ik_i) = d_k \end{aligned} ?E[qTk]=i=1dk??E[qi?ki?]=0Var(qTk)=i=1dk??Var(qi?ki?)=dk??
所以,可以看出,当 dkd_kdk? 较大时,qTkq^TkqTk 的方差较大,不同的 key 与同一个 query 算出的对齐分数可能会相差很大,有的远大于 0,有的则远小于 0.

再解释:向量内积的值(对齐分数)较大时,softmax 函数梯度很小

先介绍一下 softmax 函数:

softmaxsoftmaxsoftmax 函数是 logistic (或 sigmoid)函数在多类问题上的引申(有关于 sigmoid 函数的信息可查看我的另一篇博客),记为 SSS,其公式为:
S(xi)=exi∑j=0nexjS(x_i) = \frac{e^{x_i}}{\sum_{j=0}^n e^{x_j}} S(xi?)=j=0n?exj?exi??
S(xi)S(x_i)S(xi?) 求偏导,可得:
??xiS(xi)=S(xi)(1?S(xi))??xjS(xi)=?S(xi)S(xj)\begin{aligned} \frac{\partial}{\partial x_i} S(x_i) &= S(x_i)(1-S(x_i)) \\ \frac{\partial}{\partial x_j} S(x_i) &= -S(x_i)S(x_j) \end{aligned} ?xi???S(xi?)?xj???S(xi?)?=S(xi?)(1?S(xi?))=?S(xi?)S(xj?)?
从上面的结果可以看出:

  • xix_ixi? 相对于其他的 xj(j≠i)x_j(j \neq i)xj?(j??=i) 特别大时,S(xi)S(x_i)S(xi?) 趋近于 1,则 ??xiS(xi)\frac{\partial}{\partial x_i} S(x_i)?xi???S(xi?)??xiS(xj)\frac{\partial}{\partial x_i} S(x_j)?xi???S(xj?) 都趋近于 0.
  • xix_ixi? 相对较小时,S(xi)S(x_i)S(xi?) 趋近于 0,则 ??xiS(xi)\frac{\partial}{\partial x_i} S(x_i)?xi???S(xi?)??xiS(xj)\frac{\partial}{\partial x_i} S(x_j)?xi???S(xj?) 也都趋近于 0.

也就是,xix_ixi? 趋于 0 或 1 时,上述的两种偏导数都趋于零

现在,我们就可以把这里的 xix_ixi? 替换成前一部分讲到的 query 和 key 向量的内积 qTkq^T kqTk 了。

在前一部分我们有得出结论:当 dkd_kdk? 较大时,qTkq^TkqTk 的方差较大,不同的 key 与同一个 query 算出的对齐分数可能会相差很大,有的远大于 0,有的则远小于 0.

所以,当 dkd_kdk? 较大时,很有可能存在某个 key,其与 query 计算出来的对齐分数远大于其他的 key 与该 query 算出的对齐分数。这时, softmaxsoftmaxsoftmax 函数对各个 qTkq^TkqTk 的偏导数都趋于 0.

其结果就是, softmaxsoftmaxsoftmax 函数梯度过低(趋于零),使得模型误差反向传播(back-propagation)经过 softmaxsoftmaxsoftmax 函数后无法继续传播到模型前面部分的参数上,造成这些参数无法得到更新,最终影响模型的训练效率。

那么如何消除如上 dot-product attention 的问题呢?一种方法就是论文中的对 dot-product attention 进行缩放(除以 dk\sqrt{d_k}dk? ?),获得 scaled dot-product attention。其对齐分数的计算公式为:
score(q,k)=qTkdkscore(q, k) = \frac{q^T k}{\sqrt{d_k}} score(q,k)=dk? ?qTk?
根据方差的计算法则:Var(kx)=k2Var(x)\text{Var}(kx) = k^2\text{Var}(x)Var(kx)=k2Var(x),可知缩放后,score(q,k)score(q,k)score(q,k) 的方差由原来的 dkd_kdk? 缩小到了 1. 这就消除了 dot-product attention 在 dkd_kdk? 较大时遇到的问题。这时,softmax 函数的梯度就不容易趋近于零了。

这就是为什么 dot-product attention 需要被 scaled.

总结

本博客基于随机变量的期望和方差以及 softmaxsoftmaxsoftmax 函数的性质,详细说明了——为什么 dot-product attention 需要被 scaled.

参考源

  • Attention Is All You Need
  • Attention? Attention!

推荐资源(Transformer 相关)

  • The Illustrated Transformer(概念上)
  • The Annotated Transformer(代码实现上)
  相关解决方案