前言
在很多近期的transformer工作中,经常提到一个词: relative position bias。用在self attention的计算当中。笔者在第一次看到这个概念时,不解其意,本文用来笔者自己关于relative position bias的理解。
笔者第一次看到该词是在swin transformer。后来在focal transformer和LG-transformer中都看到它。
relative position bias(相对位置偏置)
基本形式如下:
Attention(Q,K,V)=Softmax(QKT+B)VAttention(Q, K, V) = Softmax(QK^T + B)VAttention(Q,K,V)=Softmax(QKT+B)V
其中Q,V∈Rn×dQ, V \in R^{n\times d}Q,V∈Rn×d, B∈Rn×nB \in R^{n \times n}B∈Rn×n,n是token vector的数目。可以看出,B的作用是给attention map QKTQK^TQKT的每个元素加了一个值。其本质就是希望attention map进一步有所偏重。因为attention map中某个值越低,经过softmax之后,该值会更低。对最终特征的贡献就低。
而B并不是一个随便初始化的参数,它有一个完备的使用过程。其基本过程如下:
- 初始化一个n2n^2n2的tensor作为表,同时也是个参数。
- 构建table index,用于根据位置查表。下面再介绍细节
- 前向传播中使用位置查表。
- 反向传播更新表。
在swin transformer的源代码中,可以清楚的看到相对位置偏置的使用过程。
在构造函数中,有以下相关内容:
self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Wwcoords_flatten = torch.flatten(coords, 1) # 2, Wh*Wwrelative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Wwself.register_buffer("relative_position_index", relative_position_index)
第一行就是初始化表,
后面的内容就是建立一个可以根据query和key的相对位置查表参数的index。
比如现在有一个2×22 \times 22×2的特征图。设置windows size 为(2,2),我们可以看看relative_position_index长什么样:
torch.Size([4, 4])
tensor([
[4, 3, 1, 0],
[5, 4, 2, 1],
[7, 6, 4, 3],
[8, 7, 5, 4]
])
注意看,主对角线都是4。上三角都是比4小,最低为0;下三角都比4大,且最大为8;一共9个数字,正好等于relative_position_bias_table的宽和高。
以第一行为例,第一个元素为4,第二个元素为3;对应就是方格图中标号为1和2的位置。
其实就是第一个query和第一个key都在标号1的位置,所以相对位置为0,则都使用参数表的第4个偏置;而第2个key中的元素,位置在标号1的右边一格,用参数表的第3个参数。
重点到了:只要query在key’的左边一格,relative_position_index中对应的位置都是3。比如第三行的最后一个数字。第三行对应标号3,只有标号4在其右边,而relative_position_bias_table[2][3]恰好为3。
进而可以观察其他元素之间的位置,可以发现相同的规律。因此,不难得出结论:B中值和query和key的相对位置有关系。相对位置一致的query-key pair,会采用相同的bias。
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
# 根据index提供的相对位置映射查bias,然后在view成可以和attention map计算的B
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
reference
swin transformer
图解Swin Transformer