[Causal Learning] VC R-CNN (CVPR 2020) Code
The author builds on the maskrcnn-benchmark framework (the predecessor of Detectron2). Inspired by Bottom-Up and Top-Down Attention for Image Captioning and VQA, Mask R-CNN is used as the bottom-up backbone to provide image features for downstream tasks such as Image Captioning and VQA.
As noted in the paper, the RPN is removed and GT bboxes are used as input; the training loss is changed to the sum of the self predictor's classification loss and the context predictor's causal loss (the classification_loss and causal_loss computed in loss.py below).
At test time the pipeline becomes a feature-extraction stage: the features output by ROI_HEAD are taken as the VC Features.
The config file is e2e_mask_rcnn_R_101_FPN_1x.yaml. Compared with Mask R-CNN, the author changes BASE_LR from 0.02 to 0.005 and MAX_ITER from 90k to 240k, and trains from scratch. The main files involved are under ROI_BOX_HEAD: FPN2MLPFeatureExtractor and FPNPredictor. The former does ROI Align, flattening, and two fc+ReLU layers, outputting a 1024-dim feature; the latter, FPNPredictor, handles class prediction and box regression.
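For reference, here is a minimal sketch of what FPN2MLPFeatureExtractor computes per box. The layer names fc6/fc7 and the default sizes are assumptions for illustration, and the ROI Align step itself is omitted.

import torch
from torch import nn
from torch.nn import functional as F

class FPN2MLPSketch(nn.Module):
    """Illustrative stand-in for FPN2MLPFeatureExtractor: ROI Align is omitted;
    input is the pooled [num_boxes, C, R, R] tensor it would produce."""
    def __init__(self, in_channels=256, resolution=7, representation_size=1024):
        super().__init__()
        input_size = in_channels * resolution ** 2
        self.fc6 = nn.Linear(input_size, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, pooled):
        x = pooled.flatten(start_dim=1)   # flatten after ROI Align
        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))
        return x                          # [num_boxes, 1024] -> the VC Feature at test time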
In ROIBoxHead() in box_head.py, causal_predictor and feature_save_path are added.
In predictor(), the box_regression branch is removed, leaving only classification. class_logits and class_logits_causal_list are fed into loss_evaluator(), and at test time save_object_feature_gt_bu() is executed; see the sketch below.
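The following is a minimal sketch, not the verbatim repo code, of how the modified ROIBoxHead.forward could wire these pieces together; the exact signatures and the post-processing call are assumptions based on the description above.

# Minimal sketch of the modified ROIBoxHead.forward (assumed wiring, for illustration)
def forward(self, features, proposals, targets=None):
    if self.training:
        with torch.no_grad():
            proposals = self.loss_evaluator.subsample(proposals, targets)

    # 1024-d per-box features from FPN2MLPFeatureExtractor (ROI Align + two fc layers)
    x = self.feature_extractor(features, proposals)
    # self predictor: class logits only, the box_regression branch is gone
    class_logits = self.predictor(x)
    # context predictor: one [n*n, num_classes] logit tensor per image
    class_logits_causal_list = self.causal_predictor(x, proposals)

    if not self.training:
        # attach features to the boxes, then dump them as VC Features
        result = self.post_processor(x, proposals)
        self.save_object_feature_gt_bu(x, result, targets)
        return x, result, {}

    loss_classifier, loss_causal = self.loss_evaluator(
        [class_logits], class_logits_causal_list, proposals)
    return x, proposals, dict(loss_classifier=loss_classifier, loss_causal=loss_causal)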
CausalPredictor() is added in roi_box_predictors.py:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from maskrcnn_benchmark.modeling import registry

@registry.ROI_BOX_PREDICTOR.register("CausalPredictor")
class CausalPredictor(nn.Module):
    def __init__(self, cfg, in_channels):
        super(CausalPredictor, self).__init__()

        num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
        self.embedding_size = cfg.MODEL.ROI_BOX_HEAD.EMBEDDING
        representation_size = in_channels

        # context predictor scores the concatenated pair [y; z], hence 2x the size
        self.causal_score = nn.Linear(2 * representation_size, num_classes)
        # embeddings for the attention between object features and dictionary entries
        self.Wy = nn.Linear(representation_size, self.embedding_size)
        self.Wz = nn.Linear(representation_size, self.embedding_size)

        nn.init.normal_(self.causal_score.weight, std=0.01)
        nn.init.normal_(self.Wy.weight, std=0.02)
        nn.init.normal_(self.Wz.weight, std=0.02)
        nn.init.constant_(self.Wy.bias, 0)
        nn.init.constant_(self.Wz.bias, 0)
        nn.init.constant_(self.causal_score.bias, 0)

        self.feature_size = representation_size

        # confounder dictionary (the first, background, entry is dropped) and class prior P(z)
        self.dic = torch.tensor(np.load(cfg.DIC_FILE)[1:], dtype=torch.float)
        self.prior = torch.tensor(np.load(cfg.PRIOR_PROB), dtype=torch.float)

    def forward(self, x, proposals):
        device = x.get_device()
        dic_z = self.dic.to(device)
        prior = self.prior.to(device)

        # split the flat per-batch feature tensor back into per-image chunks
        box_size_list = [proposal.bbox.size(0) for proposal in proposals]
        feature_split = x.split(box_size_list)
        # apply the intervention per image, then score every (y, z) pair
        xzs = [self.z_dic(feature_pre_obj, dic_z, prior) for feature_pre_obj in feature_split]

        causal_logits_list = [self.causal_score(xz) for xz in xzs]
        return causal_logits_list

    def z_dic(self, y, dic_z, prior):
        """
        Note that the intervention is computed over the whole batch rather than
        for one object as in the main paper.
        """
        length = y.size(0)
        if length == 1:
            print('debug')  # leftover debug hook for single-box images
        # scaled dot-product attention of each object feature over the dictionary
        attention = torch.mm(self.Wy(y), self.Wz(dic_z).t()) / (self.embedding_size ** 0.5)
        attention = F.softmax(attention, 1)
        # weight each dictionary entry by attention and the prior P(z)
        z_hat = attention.unsqueeze(2) * dic_z.unsqueeze(0)
        z = torch.matmul(prior.unsqueeze(0), z_hat).squeeze(1)
        # pair every object feature y_i with every intervened context z_j -> [n*n, 2d]
        xz = torch.cat((y.unsqueeze(1).repeat(1, length, 1),
                        z.unsqueeze(0).repeat(length, 1, 1)), 2).view(-1, 2 * y.size(1))

        # detect if we encounter NaN
        if torch.isnan(xz).sum():
            print(xz)
        return xz
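To make the tensor shapes in z_dic concrete, here is a small standalone walk-through under assumed sizes (n = 3 boxes, a K = 80 entry dictionary, 1024-dim features, 256-dim embeddings); the numbers are illustrative only.

import torch

n, K, d, e = 3, 80, 1024, 256                # boxes, dict entries, feature dim, embedding dim (assumed)
y = torch.randn(n, d)                        # per-box features from the box head
dic_z = torch.randn(K, d)                    # confounder dictionary
prior = torch.rand(K); prior /= prior.sum()  # class prior P(z)

Wy, Wz = torch.nn.Linear(d, e), torch.nn.Linear(d, e)
attention = torch.softmax(Wy(y) @ Wz(dic_z).t() / e ** 0.5, dim=1)   # [n, K]
z_hat = attention.unsqueeze(2) * dic_z.unsqueeze(0)                  # [n, K, d]
z = torch.matmul(prior.unsqueeze(0), z_hat).squeeze(1)               # [n, d]
xz = torch.cat((y.unsqueeze(1).repeat(1, n, 1),
                z.unsqueeze(0).repeat(n, 1, 1)), 2).view(-1, 2 * d)  # [n*n, 2d]
print(attention.shape, z.shape, xz.shape)   # [3, 80], [3, 1024], [9, 2048]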
In loss.py, the __call__ function of FastRCNNLossComputation() is modified:
def __call__(self, class_logits, causal_logits_list, proposals):
    """
    Computes the loss for VC R-CNN.
    This requires that the subsample method has been called beforehand.

    Arguments:
        class_logits (list[Tensor])
        causal_logits_list (list[Tensor])

    Returns:
        classification_loss (Tensor)
        causal_loss (Tensor)
    """
    class_logits = cat(class_logits, dim=0)
    device = class_logits.device

    labels = [proposal.get_field("labels").to(dtype=torch.int64) for proposal in proposals]
    labels_self = cat(labels, dim=0)

    # self predictor loss
    classification_loss = F.cross_entropy(class_logits, labels_self)

    # context predictor loss: the flattened pair (i, j) is supervised with label[j]
    causal_loss = 0.
    for causal_logit, label in zip(causal_logits_list, labels):
        mask_label = label.unsqueeze(0).repeat(label.size(0), 1)
        # zero the diagonal so an object never predicts its own class
        mask = 1 - torch.eye(mask_label.size(0)).to(device)
        loss_causal = F.cross_entropy(causal_logit, mask_label.view(-1), reduction='none')
        loss_causal = loss_causal * mask.view(-1)
        causal_loss += torch.mean(loss_causal)

    return classification_loss, causal_loss
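A toy illustration of the masking, assuming one image with 3 GT boxes labeled [2, 5, 7]: each flattened pair (i, j) is supervised with label[j], and the diagonal self-pairs get weight 0.

import torch

label = torch.tensor([2, 5, 7])               # assumed GT labels of 3 boxes
mask_label = label.unsqueeze(0).repeat(3, 1)  # every row is [2, 5, 7]
mask = 1 - torch.eye(3)                       # zeros on the diagonal
print(mask_label.view(-1))  # tensor([2, 5, 7, 2, 5, 7, 2, 5, 7])
print(mask.view(-1))        # tensor([0., 1., 1., 1., 0., 1., 1., 1., 0.])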
A function is added to ROIBoxHead() in box_head.py to save features during testing:
def save_object_feature_gt_bu(self, x, result, targets):
    # requires `import os` and `import numpy as np` at the top of box_head.py
    for i, image in enumerate(result):
        feature_pre_image = image.get_field("features").cpu().numpy()
        try:
            # sanity check: one feature vector per GT box
            assert image.get_field("num_box")[0] == feature_pre_image.shape[0]
            image_id = str(image.get_field("image_id")[0].cpu().numpy())
            path = os.path.join(self.feature_save_path, image_id) + '.npy'
            np.save(path, feature_pre_image)
        except Exception:
            # log the offending image instead of aborting feature extraction
            print(image)
In summary, the author removed everything related to bbox regression. The Mask R-CNN used here requires GT bboxes at test time, so in a sense it only performs classification and carries no localization information; the VC Feature therefore cannot be used alone and must be combined with the Up-Down feature. This differs from the Faster R-CNN in the Up-Down paper, which can also serve as an object detector. The paper's experiments likewise show that "Only VC" performs worse than the original ("Origin") features.
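Accordingly, a downstream captioning/VQA model would load the saved .npy file and concatenate it with a precomputed Up-Down feature; the paths and the 2048-dim Up-Down size below are illustrative assumptions.

import numpy as np

# hypothetical paths; VC R-CNN saves one .npy per image, keyed by image_id
vc_feat = np.load('vc_features/123456.npy')         # [num_boxes, 1024] VC Feature
bu_feat = np.load('updown_features/123456.npy')     # [num_boxes, 2048] Up-Down feature (assumed size)
fused = np.concatenate([bu_feat, vc_feat], axis=1)  # [num_boxes, 3072] fed to the downstream model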