BMVC2020 Best Paper: Delving Deeper into Anti-aliasing in ConvNets论文解读

    科技2022-07-11  102

    官方代码

    背景

    虽然convNets具有平移不变形,但貌似这种平移不变性还不是特别理想。之前的一些研究发现,把图像在不同的位置crop,对裁剪的图像预测,当crop位置不同,图像得到的分类概率也不一样。对于其他下游任务,如检测,分割,平移都会带来一些错误的预测(aliasing)。此篇论文目的就是减弱这种因平移带来的错判问题(anti-aliasing)。

    方法

    作者首先从信息论奈奎斯特采样率来引入话题。仅仅把第一行的二进制右移一位,然后采样,得到的数值能发生很大的差异。因此,下采样对信息是有破坏作用的。但为了网络的计算量,下采样是不可避免的。因此需要设计出一种尽量少丢失重要信息的下采样方式。

    实际上之前是有一些研究的,他们通常都是在下采样层之前加入高斯模糊。但作者认为,给空间位置加高斯模糊还不是最好的,因为没有考虑通道之间也有频率的变化。而且在一些情况下,过度使用高斯模糊,导致高频信息的丢失,但某些高频信息恰好是需要的(比如下图的边缘,在第三张图里面,已经很淡了)。 作者提出content-aware anti-aliasing模块,用自适应权重来完成加权模糊,让模型自己决定特征的重要程度,并自己预测模糊核系数,同时考虑在空间和通道两方面。

    content-aware anti-aliasing

    该模块的结构图如下 有两个步骤。第一步是用3x3卷积得到一系列maps,然后对maps分成g组,且每组的通道数是k^2。每组内部执行Spatial adaptive anti-aliasing。 所有组结合起来就是Channel-grouped adaptive anti-aliasing

    Spatial adaptive anti-aliasing

    self.conv = nn.Conv2d(in_channels, group*kernel_size*kernel_size, kernel_size=kernel_size, stride=1, bias=False)

    先得到g*k*k个通道的maps。接着每个位置(i,j)都对应了一个k*k的向量,reshape为k x k,作为一个核。 加权对应位置的特征。

    Channel-grouped adaptive anti-aliasing

    将上述操作,作用于每个group,然后concat组成c个通道的特征 作者为了平衡性能和计算量,将特征分为g组(g=8)。同一组的特征的不同channel,都使用相同的weight。作者认为特征channel之间也有相似的,因此可以共享系数。

    这部分代码实现需要参考论文,没发简单的讲明白。 一顿reshape permute操作。 之后接一个stride为2的采样。

    x[:,:,torch.arange(h)%self.stride==0,:][:,:,:,torch.arange(w)%self.stride==0]

    该模块insert在每个下采样之前(也就是每个stage的第一层),代替原本的stride为2的conv。

    凡是下采样的地方,都用这个模块代替。那么resnet有三种用到下采样的地方。第一个是stage0。也就是第一层conv block

    self.maxpool = nn.Sequential(*[nn.MaxPool2d(kernel_size=2, stride=1), Downsample_PASA_group_softmax(kernel_size=filter_size, stride=2, in_channels=planes[0], group=pasa_group)])

    第二个是,每个stage的开头(除了stage1)。

    downsample = [Downsample_PASA_group_softmax(kernel_size=filter_size, stride=stride, in_channels=self.inplanes, group=pasa_group),] if(stride !=1) else [] downsample += [conv1x1(self.inplanes, planes * block.expansion, 1), norm_layer(planes * block.expansion)] # print(downsample) downsample = nn.Sequential(*downsample)

    第三个是,当stride为2,skip connection也需要一次下采样。

    self.conv3 = nn.Sequential(Downsample_PASA_group_softmax(kernel_size=filter_size, stride=stride, in_channels=planes, group=pasa_group), conv1x1(planes, planes * self.expansion))
    Processed: 0.029, SQL: 8