参数量方面的确控制的不错,不过Flops和效果方面…还有待观察 https://arxiv.org/pdf/2010.03045.pdf https://github.com/LandskapeAI/triplet-attention
受益于在通道或空间位置之间建立相互依赖性的能力,注意力机制最近已得到广泛研究,并广泛用于各种计算机视觉任务中。在本文中,我们研究了轻量但有效的注意力机制,并提出了triplet attention,这是一种通过使用三分支结构捕获跨维度交互来计算注意力权重的新方法。对于输入张量,triplet attention通过旋转操作,然后使用残差变换建立维度间的依存关系,并以可忽略的计算开销对通道间和空间信息进行编码。我们的方法既简单又有效,并且可以轻松地作为附加模块插入经典骨干网络。我们证明了该方法在各种挑战性任务中的有效性,包括ImageNet-1k上的图像分类以及MSCOCO和PASCAL VOC数据集上的目标检测。此外,我们可视化了GradCAM和GradCAM ++结果,提供了对triplet attention表现的广泛见解。大量实验结果证明了我们的直觉,即在计算注意力权重时捕获跨维度依赖性的重要性。
关键词:跨通道交互信息
作者观察到CBAM中的通道注意力方法虽然提供了显着的性能改进,却不是因为跨通道交互。然而,作者展示了捕获通道交互时对性能会产生有利的影响。此外,CBAM在计算通道注意力时结合了降维功能。这在捕获通道之间的非线性局部依赖关系方面是多余的。
因此,本文提出了可以有效解决跨维度交互的triplet attention。相较于以往的注意力方法,主要有两个优点:
1.可以忽略的计算开销
2.强调了多维交互而不降低维度的重要性,因此消除了通道和权重之间的间接对应。
如上图所示,Triplet Attention主要包含3个分支,其中两个分支分别用来捕获通道C维度和空间维度W/H之间的跨通道交互,剩下的一个分支就是传统的通道注意力权重的计算。
A.网络结构 具体的网络结构如上图所示:
1.第一个分支:通道注意力计算分支,输入特征经过Z-Pool,再接着7 x 7卷积,最后Sigmoid激活函数生成通道注意力权重
2.第二个分支:通道C和空间W维度交互捕获分支,输入特征先经过permute,变为H X C X W维度特征,接着在H维度上进行Z-Pool,后面操作类似。最后需要经过permuter变为C X H X W维度特征,方便进行element-wise相加
3.第三个分支:通道C和空间H维度交互捕获分支,输入特征先经过permute,变为W X H X C维度特征,接着在W维度上进行Z-Pool,后面操作类似。最后需要经过permuter变为C X H X W维度特征,方便进行element-wise相加
最后对3个分支输出特征进行相加求Avg
B.Z-Pool
对输入进行MaxPooling和AvgPooling,输出2 X H X W特征 C.代码:
class ZPool(nn.Module): def forward(self, x): return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) class AttentionGate(nn.Module): def __init__(self): super(AttentionGate, self).__init__() kernel_size = 7 self.compress = ZPool() self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) def forward(self, x): x_compress = self.compress(x) x_out = self.conv(x_compress) scale = torch.sigmoid_(x_out) return x * scale class TripletAttention(nn.Module): def __init__(self, no_spatial=False): super(TripletAttention, self).__init__() self.cw = AttentionGate() self.hc = AttentionGate() self.no_spatial=no_spatial if not no_spatial: self.hw = AttentionGate() def forward(self, x): x_perm1 = x.permute(0,2,1,3).contiguous() x_out1 = self.cw(x_perm1) x_out11 = x_out1.permute(0,2,1,3).contiguous() x_perm2 = x.permute(0,3,2,1).contiguous() x_out2 = self.hc(x_perm2) x_out21 = x_out2.permute(0,3,2,1).contiguous() if not self.no_spatial: x_out = self.hw(x) x_out = 1/3 * (x_out + x_out11 + x_out21) else: x_out = 1/2 * (x_out11 + x_out21) return x_out1.模块复杂度,几乎无增长:
2.ImageNet:参数量不涨,但是Flops比别人多啊… 3.Object Detection:效果不咋样…
