Faster-RCNN检测-RPN
程序员文章站
2024-03-14 11:33:04
...
主要贡献
提出了 RPN(Region Proposal Networks) 网络来计算候选框
主要步骤:
- 特征提取: 同 Fast R-CNN,以整张图为输入,利用 CNN 得到图像的特征层
- 区域提名: 在最终的卷积特征层conv5-3上利用K 个不同的矩形框 (Anchor Box) 进行提名, k 一般取 9
- 分类与回归: 对每一个 Anchor Box 对应的区域:
- 进行 object/non-object 二分类
- 用 k 个回归模型(对应不同的 Anchor Box )微调候选框位置和大小,最终进 行目标分类
注: Faster R-CNN 抛弃了 Selectave Search, 引入了 RPN 网络,使得区域提名,分类,回归一起共用卷积特征,从而得 到了进一步加速。但是 Faster R-CNN 需要对两万个 Anchor Box 先判断是否是目标(二分类),然后在进行目标识别,分成了两部分
RPN网络
特征细化:RPN层的第一步是用[3,3,512]的卷机核在conv5-3上进行卷机操作,conv5-3上的每一个像素点,对应的是原始图像中的近似于16x16的区域,所以这也就是为什么文章中说的把每一个中心点的像素转换为512-D的vector,feature map的个数变为了512,于是比如[0,0,:]就是一个512-D的vector
-
网络输出:在这个512-D的vector基础上,又有两个全卷机网络,
- 一个是1x1的卷积核,单输出9x2维,因为有9个anchor,每个anchor都有两个值,前景和背景,所以是18。所以这个输出的大小的height和width与conv5-3的大小一致,用于定义objetness。
- 另一个全卷机核大小维1x1,输出为9x4,因为有9个anchor,每个anchor有4个值,这4个值为预测的tx,ty,tw,th。所以这个输出feature的height和width与conv5-3的大小一致,但深度为36。用于定义pred-box。
-
区域生成:
-
anchor生成:
// return: // rpn_labels: [HxWxA, 2] // rpn_bbox_targets: [HxWxA, 4] rpn_labels, rpn_bbox_targets, rpn_bbox_inside_weights,rpn_bbox_outside_weights = anchor_target_layer(rpn_cls_score, # (1, H, W, Ax2) bg/fg, 只提供feature map的height,width信息 gt_boxes, # (g, 5) vstack of [x1, y1, x2, y2, class] gt_ishard, dontcare_ares, im_info, _feat_stride=[16,], anchor_scales=[4,8,16,32])
- 这个方法首先把越界的anchor都过滤掉,保留都在图像范围内的anchor。
- 然后创建一个全部是-1的label。
- 接着计算每一个anchor与ground_truth的overlap,overlap返回一个二元数组,行数代表anchor的个数,列数代表ground-truth的个数。
- 从中选择max-overlap,如果max-overlap大于某个阈值,那么这个anchor的label就设置为包含目标,用1表示。如果max-overlap小于某个阈值,那么这个anchor的label就设置为0。
- 然后在找到每个ground-truth和anchor覆盖最大的anchor的index,把这些anchor设置为1,从而避免某个ground-truth没有对应的anchor。
- 对每个anchor都设置是否含有目标后,利用anchors和每个anchors对应的max-overlap的ground-truth来计算该anchor对应的tx*, ty*, tw*, th*。
- 然后设置bbox_inside_weights,这个权值起到的作用是论文中的公示(1)中的pi*。bbox_outside_weights权值用来设置在所有样本中,postive和negitive的权值。由于上述所有操作中都是在没有超越边界的anchor中进行的,所以需要还原回所有的anchors中,于是使用了方法_unmap
- 该方法最后返回:
- rpn_label:这事真实的每一个anchor是否含有目标还是没有目标
- rpn_bbox_targets:这个是真实的每个anchor与其覆盖最大的ground-truth来计算得到的tx,ty,tw,th
-
proposal生成
// return: // rpn_rois: [1xHxWxA, 5] e.g. [0, x1, y1, x2, y2] rpn_rois = proposal_layer(rpn_cls_prob_reshape, # [1,H,W,Ax2] outputs of RPN, prob of bg or fg rpn_bbox_pred, # [1,H,W,Ax4], rgs boxes output of RPN im_info, cfg_key, _feat_stride=[16,], anchor_scales=[8,16,32])
- 该方法首先根据rpn_bbox_pred来生成原始图像中的anchor的预测坐标,由于rpn_bbox_pred是tx,ty,tw,th(计算方法参考论文),对rpn_cls_prob进行排序,根据objectness分数进行高低排序,然后选出需要保留的proposal个数,论文中设置为6000,然后从这些proposal中使用nms算法,筛选出最后的proposal,返回这些proposal和score。注意,这些proposal和score都是排序后的。
-
target生成
// return: // rois: [1xHxWxA,5] e.g. [0, x1, y1, x2, y2] // labels: [1xHxWxA, 1] e.g. {0,1,2,...,_num_classes-1} // bbox_targets: [1xHxWxA, Kx4] e.g. [dx1, dy1, dx2, dy2] // bbox_inside_weights: [1xHxWxA, kx4] e.g. 0/1 masks for the computing loss // bbox_outside_weight: [1xHxWxA, kx4] e.g. 0/1 masks for the computing loss roi, labels, bbox_targets, bbox_inside_weights, bbox_outside_weights = proposal_target_layer(rpn_roi, # (1xHxWxA, 5) e.g. [0,x1,y1,x2,y2] gt_boxes, # (G,5) e.g. [x1,y1,x2,y2,class] gt_ishard, dontcare_area, _num_classes)
- 首先计算fg_rois_per_image,也就是一个batch中认为是前景的roi的个数,剩余的认为是背景。
- 然后计算每一个rois和ground_truth的overlap,该overlap返回的数组形式为[roi_size, gt_size]。
- 从这些rois中随机选择一些正样本和负样本,max_overlap大于某个阈值的roi被认为是正样本,建立label【label是每一个正样本的类别标签,voc为20类,是某个数字】。按照比例设定背景样本,背景样本的标签为0。该方法中调用了一个—_sample_rois的方法,该方法返回值为:
- labels:每一个roi的类别标签,
- rois:就是原来所有rois进行正负样本过滤后,选择出来的正样本和负样本。
- roi_scores:对应选择出来的正负样本的objectness score
- bbox_targets:该返回值为数组,target_num=[num_rois, num_class*4],取其中一行作为例子,比如target_num[0],该vector的长度为80,首先设置全部为0,如果target_num[0]的类别是3,那么设置target_nums[0,3*4:(3*4+4)]的取值为tx,ty,tw,th。
- 这个方法返回的rois会接着送到后面的fast-rcnn网络中。该方法中计算出来的labels,boxs都作为真实值。
-
上一篇: 时序算法lstm实测代码
下一篇: autojs实现无root录屏
推荐阅读
-
Faster-RCNN检测-RPN
-
目标检测(Object Detection)
-
在谷歌目标检测(Google object_detection) API 上训练自己的数据集
-
tensorflow入门教程(二十五)Object Detection API目标检测(下)
-
ImageAI (二) 使用Python快速简单实现物体检测 Object Detection
-
tensorflow入门教程(二十四)Object Detection API目标检测(中)
-
使用tensorflow object_detection API完成目标检测
-
faster-rcnn tensorflow CPU版代码与demo运行流程
-
PHP检测数据类型的几种方法(总结)
-
PHP随机获取未被微信屏蔽的域名(微信域名检测)