SiamMask code reading (3): differences between siammask.py and siammask_sharp.py

程序员文章站 2022-07-12 13:17:25

1. In siammask.py:

	class SiamMask(nn.Module):
	    def __init__(self, anchors=None, o_sz=63, g_sz=127):

In siammask_sharp.py:

	class SiamMask(nn.Module):
	    def __init__(self, anchors=None, o_sz=127, g_sz=127):

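The first difference is the default value of the parameter o_sz, which changes from 63 to 127 and so becomes equal to g_sz.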

2. In siammask.py:

	def run(self, template, search, softmax=False):
	    """
	    run network
	    """
	    template_feature = self.feature_extractor(template)
	    search_feature = self.feature_extractor(search)
	    rpn_pred_cls, rpn_pred_loc = self.rpn(template_feature, search_feature)
	    rpn_pred_mask = self.mask(template_feature, search_feature)  # (b, 63*63, w, h)

In siammask_sharp.py:

	def run(self, template, search, softmax=False):
	    """
	    run network
	    """
	    template_feature = self.feature_extractor(template)
	    feature, search_feature = self.features.forward_all(search)
	    rpn_pred_cls, rpn_pred_loc = self.rpn(template_feature, search_feature)
	    corr_feature = self.mask_model.mask.forward_corr(template_feature, search_feature)  # (b, 256, w, h)
	    rpn_pred_mask = self.refine_model(feature, corr_feature)
	

The second difference is how the mask is generated: siammask.py uses the first method from the paper, while siammask_sharp.py uses the refine method.
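
The comment on the siammask.py mask output gives its shape as (b, 63*63, w, h). The sketch below is not code from the repository; it only tracks how a tensor of that shape can be read as one o_sz x o_sz mask per spatial position after a permute and view. The helper name `per_position_mask_shape` and the example sizes (a batch of 2 and a 25 x 25 response map) are assumptions chosen for illustration.

```python
def per_position_mask_shape(b, w, h, o_sz):
    """Shape after permute(0, 2, 3, 1) followed by view(-1, 1, o_sz, o_sz).

    The input is assumed to have shape (b, o_sz * o_sz, w, h): every spatial
    position carries a flattened o_sz x o_sz mask in its channel vector.
    """
    channels = o_sz * o_sz
    total = b * channels * w * h   # a view never changes the element count
    rows = total // (o_sz * o_sz)  # the dimension that -1 is inferred to be
    return (rows, 1, o_sz, o_sz)


# Hypothetical sizes: batch of 2, 25 x 25 response map, o_sz = 63.
shape = per_position_mask_shape(b=2, w=25, h=25, o_sz=63)
print(shape)  # (1250, 1, 63, 63): one mask per (batch, position) pair
```

In the refine variant the correlation feature is (b, 256, w, h) according to the source comment, so the mask is no longer stored in the channel vector and this reshaping does not apply to it directly.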

3. In siammask.py:

	p_m = p_m.permute(0, 2, 3, 1).contiguous().view(-1, 1, o_sz, o_sz)
	p_m = torch.index_select(p_m, 0, pos)
	p_m = nn.UpsamplingBilinear2d(size=[g_sz, g_sz])(p_m)
	p_m = p_m.view(-1, g_sz * g_sz)
	
	mask_uf = F.unfold(mask, (g_sz, g_sz), padding=32, stride=8)
	mask_uf = torch.transpose(mask_uf, 1, 2).contiguous().view(-1, g_sz * g_sz)

In siammask_sharp.py:

	if len(p_m.shape) == 4:
	    p_m = p_m.permute(0, 2, 3, 1).contiguous().view(-1, 1, o_sz, o_sz)
	    p_m = torch.index_select(p_m, 0, pos)
	    p_m = nn.UpsamplingBilinear2d(size=[g_sz, g_sz])(p_m)
	    p_m = p_m.view(-1, g_sz * g_sz)
	else:
	    p_m = torch.index_select(p_m, 0, pos)
	
	mask_uf = F.unfold(mask, (g_sz, g_sz), padding=0, stride=8)
	mask_uf = torch.transpose(mask_uf, 1, 2).contiguous().view(-1, g_sz * g_sz)

The third difference is the added conditional on the shape of p_m and the padding argument of F.unfold, which changes from 32 to 0.
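
To see what the padding change does to the number of ground-truth mask windows, the sketch below applies the output-size formula documented for `F.unfold` in plain Python. The kernel size 127 (g_sz) and stride 8 come from the code above; the mask side length of 255 is an assumption used only for illustration, and `unfold_windows_per_side` is a hypothetical helper, not a function in the repository.

```python
def unfold_windows_per_side(size, kernel, padding, stride, dilation=1):
    """Number of sliding windows along one spatial side (F.unfold formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


G_SZ = 127    # window size, g_sz in the source
STRIDE = 8    # stride used in both files
SIDE = 255    # assumed mask side length, for illustration only

for padding in (32, 0):
    n = unfold_windows_per_side(SIDE, G_SZ, padding, STRIDE)
    print(f"padding={padding}: {n} x {n} = {n * n} windows")
# padding=32: 25 x 25 = 625 windows
# padding=0: 17 x 17 = 289 windows
```

With the same input, dropping the padding therefore yields fewer windows, so the number of rows in mask_uf has to match the number of positions the corresponding model predicts masks for.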