这篇论文提出一种Attention Module结构,通过堆叠这种结构实现更好的分类效果。其主要贡献有如下三点:
- 堆叠网络结构:整个网络是基于相同的attention module 堆叠构建的,运用了混合注意力机制(后面会讲解什么是混合注意力)。每个模块有不同的注意力类型。
- 引入了残差注意力学习:如果直接堆叠注意力模块会导致明显的性能下降。引入残差学习既能解决这个问题,也能保证网络的深度。
- 使用编码解码结构产生注意力mask,使用mask作为特征的软加权。其目的是指引学习更加值得注意的特征(判断度更加高)
不同阶段有不同的关注点
左边的图示意注意力机制的工作方法,有两条分支,主分支依旧提取特征,旁分支学习一种注意力掩码,告诉网络哪里是值得注意的地方。天空和背景是黑色的,热气球被突出。
右边的图说明不同特征对应的注意力区域也不一样。天空掩码会减少低级的颜色特征的区域。到了高级特征,掩码聚焦在物体上或者物体的一部分上。
另外要知道的是,注意力掩码随着主干路特征不同自动的变化。就是说输入什么样的图,mask就会自动变化成相应的注意力区域。
Attention Module
这个模块有两个分支,一个称trunk branch,另一个是mask branch。trunk branch是用来提取特征的,可以使用当前流行的结构;mask branch是一个编码解码结构,输出是一个和输入同尺寸的特征图。记输入为x,trunk branch的输出为F(x),mask branch的输出为M(x),那么模块输出为
Hi,c(x)=(1+Mi,c(x))?Fi,c(x)H_{i,c}(x) = (1 + M_{i,c}(x))*F_{i,c}(x) Hi,c?(x)=(1+Mi,c?(x))?Fi,c?(x)
先解释一下符号。i就是特征图的空间位置的坐标,c是通道维度上的index。公式的意思很明显了,就是将mask和提取的特征做对应元素相乘,在加上提取的特征。这就是attention residual learning。
attention residual learning中residual的含义
H的公式其实是论文一大亮点。一般来说,我们提取的特征,直接和得到的mask做对应元素相乘就可以了,为什么还要加上一个特征呢?即模块为何不是如下的输出:
Hi,c(x)=Mi,c?Fi,cH_{i,c}(x) = M_{i,c}*F_{i,c}Hi,c?(x)=Mi,c??Fi,c?
论文解释了这样做是不可以的。
首先作者做了实验,上面的做法会导致性能的大幅下降。
然后作者指出:
- 重复使用mask点乘(对应元素相乘)会在深层的输出减少特征图的值。
- 直接点乘破坏了残差学习的恒等映射。因为trunk branch用的是残差单元。
解决的办法就是对mask branch也进行残差学习。如果mask的输出导致效果不好,那么通过引入mask的残差,再不济也是进行恒等映射。
总体结构
整个模型分为三个阶段。每个阶段都至少有一个 Attention Module。每个模块分两路,也是之前提到的。更多细节在下面。
至于实验结构部分就略去了。