语义分割之《Dual Attention Network for Scene Segmentation》论文阅读
论文地址:Dual Attention Network for Scene Segmentation
代码地址:github
1. 简介
文章中介绍了一种注意力机制,分为两个模块PAM(position attention module)和CAM(channel attention module),结构和non-local之类的attention非常相似。作者使用这个结构的用意是通过不同通道和不同像素之间相互影响生成attention map来进行feature中的区域选取。实现方式基本就是attention的思想:通道加权,像素加权。
另外作者基于这两个模块设计了DANet。
2. 结构
下图是DANet的整体结构,基础网络是经过修改的ResNet,后面接两个分支PAM和CAM,最后将两个分支的feature进行相加经过一个卷积层得到最终的分割图。DANet的backbone是将Resnet最后的两个block的下采样层去掉之后使用dilation convolution(空洞卷积)实现保证网络的feture map尺寸只下降了。
2.1 Position Attention Module
途中PAM分为四个分支,其中生成Attention的上BC两个分支,D用来生成经过attention map处理的feature,最下面的是保留原有的feature保证信息不丢失。
基本的流程:
- 输入,BCD三个分支分别经过conv+batch_norm+relu得到对应的输出,,
-
经过reshape维度编程,然后得到,得到的输出经过softmax得到的相应图,其中表示feature上点对的影响。
- 得到的attention map和输入进行相乘的操作;
- 最后和原有的输入进行加权的相加,如下公式,其中初始化为0,是可学习的。
下面为流程图,不会出现feature之类的标签,主要以维度为主如C*H*W表示feature或者attention map是通道高宽为。
2.2 Channel Attention Module
CAM模块的操作和PAM非常相似不同之处是操作室基于通道的生成的attention map 大小为,最后的加权参数为也是可学习的,得到的attention map上的点表示通道对的影响。不做赘述。下面为流程图
2.3 代码实现
CAM:
def hw_flatten(x):
return K.reshape(x, shape=[K.shape(x)[0], K.shape(x)[1]*K.shape(x)[2], K.shape(x)[3]])
def cam(x):
f = hw_flatten(x) # [bs, h*w, c]
g = hw_flatten(x) # [bs, h*w, c]
h = hw_flatten(x) # [bs, h*w, c]
s = K.batch_dot(K.permute_dimensions(hw_flatten(g), (0, 2, 1)), hw_flatten(f))
beta = K.softmax(s, axis=-1) # attention map
o = K.batch_dot(hw_flatten(h),beta) # [bs, N, C]
o = K.reshape(o, shape=K.shape(x)) # [bs, h, w, C]
x = gamma * o + x
return x
PAM:
def hw_flatten(x):
return K.reshape(x, shape=[K.shape(x)[0], K.shape(x)[1]*K.shape(x)[2], K.shape(x)[3]])
def pam(x):
f = K.conv2d(x, kernel= kernel_f, strides=(1, 1), padding='same') # [bs, h, w, c']
g = K.conv2d(x, kernel= kernel_g, strides=(1, 1), padding='same') # [bs, h, w, c']
h = K.conv2d(x, kernel= kernel_h, strides=(1, 1), padding='same') # [bs, h, w, c]
s = K.batch_dot(hw_flatten(g), K.permute_dimensions(hw_flatten(f), (0, 2, 1))) #[bs, N, N]
beta = K.softmax(s, axis=-1) # attention map
o = K.batch_dot(beta, hw_flatten(h)) # [bs, N, C]
o = K.reshape(o, shape=K.shape(x)) # [bs, h, w, C]
x = gamma * o + x
return x
3. 实验结果
实验:分别在Citycapes,PASCAL Context, COCO Stuff上进行了实验。
3.1 实验设置
- 学习率初始化和更新为:
- 基础学习率为0.01, Cityscapes为0.001, momentum=0.9, weight-decay=0.0001;
- 数据增强:随机左右翻转
3.2 结果
可视化:
上一篇: Vue核心技术-18,父子组件通信
下一篇: jsp跳转到servlet找不到路径解决