欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

轻量型注意力模块:ULSAM

程序员文章站 2024-03-22 18:25:34
...

ULSAM: Ultra-Lightweight Subspace Attention Module for Compact Convolutional Neural Networks
论文地址

作者提出了一种新的用于紧凑网络神经网络的注意力块(ULSAM),它可以学习每个特征子空间的个体注意力映射,并能够在多尺度、多频率特征学习的同时高效地学习跨信道信息。
轻量型注意力模块:ULSAM
主要思想:将提取的特征分成g组,对每组的子特征(论文中称问subspace)进行空间上的重新校准,最后,把g组特征concatenate到一起。具体做法看下面代码。首先用1×1 depthwise conv对每组特征提取channel为nin的新特征,然后maxpool,再pointwise conv成channel为1的attention map,最后利用softmax对attention map在H维缩放,确保attention map的权重和为1。

class SubSpace(nn.Module):
    """
    Subspace class.
    ...
    Attributes
    ----------
    nin : int
        number of input feature volume.
    Methods
    -------
    __init__(nin)
        initialize method.
    forward(x)
        forward pass.
    """

    def __init__(self, nin):
        super(SubSpace, self).__init__()
        self.conv_dws = nn.Conv2d(
            nin, nin, kernel_size=1, stride=1, padding=0, groups=nin
        )
        self.bn_dws = nn.BatchNorm2d(nin, momentum=0.9)
        self.relu_dws = nn.ReLU(inplace=False)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        self.conv_point = nn.Conv2d(
            nin, 1, kernel_size=1, stride=1, padding=0, groups=1
        )
        self.bn_point = nn.BatchNorm2d(1, momentum=0.9)
        self.relu_point = nn.ReLU(inplace=False)

        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        out = self.conv_dws(x)
        out = self.bn_dws(out)
        out = self.relu_dws(out)

        out = self.maxpool(out)

        out = self.conv_point(out)
        out = self.bn_point(out)
        out = self.relu_point(out)

        m, n, p, q = out.shape
        out = self.softmax(out.view(m, n, -1))
        out = out.view(m, n, p, q)

        out = out.expand(x.shape[0], x.shape[1], x.shape[2], x.shape[3])

        out = torch.mul(out, x)

        out = out + x

        return out

分析一下整个attention map的计算复杂度,主要就是dw conv中的nin×h×w×1×1跟pw conv中nin×h×w×1,同时考虑到原始特征由分组而来,计算量缺失很小。

轻量型注意力模块:ULSAM
从上图结果来看,(POS11:1表示在第11个layer后在使用ULSAM),分组数g=4时可以获得较好的结果。虽然,整体增益不大,但考虑在整体计算复杂度基本不变,此方法确实有些意思。但有个问题,该方法增加了add,mul等elment-wise操作,这也会增加计算负担,并且在flops和param中无法体现。