SiamMask code reading (3): differences between siammask.py and siammask_sharp.py
2022-07-12 13:17:25
1. In siammask.py:
class SiamMask(nn.Module):
    def __init__(self, anchors=None, o_sz=63, g_sz=127):
In siammask_sharp.py:
class SiamMask(nn.Module):
    def __init__(self, anchors=None, o_sz=127, g_sz=127):
The first difference is the default value of the o_sz parameter, which changes from 63 to 127.
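To make the role of o_sz concrete, here is a minimal sketch based only on the `(b, 63*63, w, h)` shape comment and the `.view(-1, 1, o_sz, o_sz)` call quoted in this post. The helper names `patch_channels` and `needs_upsampling` are hypothetical and are not part of SiamMask.

```python
def patch_channels(o_sz):
    """Channels needed so one spatial location can be viewed as an o_sz x o_sz patch."""
    return o_sz * o_sz


def needs_upsampling(o_sz, g_sz):
    """True when an o_sz x o_sz patch must be resized to reach g_sz x g_sz."""
    return o_sz != g_sz


for o_sz in (63, 127):
    # Prints: 63 3969 True, then 127 16129 False
    print(o_sz, patch_channels(o_sz), needs_upsampling(o_sz, g_sz=127))
```

With o_sz=63 the predicted patch is smaller than the 127 x 127 ground-truth window, so it has to be upsampled before the two are compared. With o_sz=127 the sizes already match.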
2. In siammask.py:
def run(self, template, search, softmax=False):
    """
    run network
    """
    template_feature = self.feature_extractor(template)
    search_feature = self.feature_extractor(search)
    rpn_pred_cls, rpn_pred_loc = self.rpn(template_feature, search_feature)
    # the mask branch predicts the mask directly from the two features
    rpn_pred_mask = self.mask(template_feature, search_feature)  # (b, 63*63, w, h)
In siammask_sharp.py:
def run(self, template, search, softmax=False):
    """
    run network
    """
    template_feature = self.feature_extractor(template)
    # keep the extra search-branch features as well; they are passed to the refine module below
    feature, search_feature = self.features.forward_all(search)
    rpn_pred_cls, rpn_pred_loc = self.rpn(template_feature, search_feature)
    corr_feature = self.mask_model.mask.forward_corr(template_feature, search_feature)  # (b, 256, w, h)
    rpn_pred_mask = self.refine_model(feature, corr_feature)
The second difference is how the mask is generated. siammask.py uses the first method from the paper: the mask branch outputs the mask directly. The sharp version uses the refine method: it first computes a correlation feature, then passes it to refine_model together with the features from the search branch.
3. In siammask.py:
p_m = p_m.permute(0, 2, 3, 1).contiguous().view(-1, 1, o_sz, o_sz)
p_m = torch.index_select(p_m, 0, pos)
p_m = nn.UpsamplingBilinear2d(size=[g_sz, g_sz])(p_m)
p_m = p_m.view(-1, g_sz * g_sz)
mask_uf = F.unfold(mask, (g_sz, g_sz), padding=32, stride=8)
mask_uf = torch.transpose(mask_uf, 1, 2).contiguous().view(-1, g_sz * g_sz)
In siammask_sharp.py:
if len(p_m.shape) == 4:
    # 4-D prediction: reshape to patches, keep the positive ones, upsample to g_sz
    p_m = p_m.permute(0, 2, 3, 1).contiguous().view(-1, 1, o_sz, o_sz)
    p_m = torch.index_select(p_m, 0, pos)
    p_m = nn.UpsamplingBilinear2d(size=[g_sz, g_sz])(p_m)
    p_m = p_m.view(-1, g_sz * g_sz)
else:
    # otherwise only select the positive rows
    p_m = torch.index_select(p_m, 0, pos)
mask_uf = F.unfold(mask, (g_sz, g_sz), padding=0, stride=8)
mask_uf = torch.transpose(mask_uf, 1, 2).contiguous().view(-1, g_sz * g_sz)
The third difference is twofold: a conditional check on the shape of p_m was added, so the reshape and upsampling steps run only when p_m is 4-D, and the padding argument of F.unfold changed from 32 to 0.
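The effect of the padding change can be estimated with the sliding-window formula documented for torch.nn.Unfold. The sketch below is plain Python. It assumes a 255 x 255 ground-truth mask, a size that is not stated in this post; g_sz=127 and stride=8 are taken from the code above. The helper name `unfold_windows` is hypothetical.

```python
def unfold_windows(size, kernel, padding, stride, dilation=1):
    """Sliding-window count along one axis, following the torch.nn.Unfold formula."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Assumption: the mask is 255 x 255; g_sz and stride come from the quoted code.
mask_size, g_sz, stride = 255, 127, 8

for padding in (32, 0):
    n = unfold_windows(mask_size, g_sz, padding, stride)
    # Prints 25 x 25 = 625 for padding=32 and 17 x 17 = 289 for padding=0
    print(f"padding={padding}: {n} x {n} = {n * n} windows per image")
```

Under that assumption, dropping the padding shrinks the grid of ground-truth windows from 25 x 25 to 17 x 17. The number of rows in mask_uf must therefore match a correspondingly smaller grid of predicted positions.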