源码阅读|Faster RCNN(五)——ROI Headers

源码阅读 · 2023-03-30 · 275 人浏览

整体思路

​ RoI-Header共由三部分组成:

  • box_roi_pool:Multi-scale RoIAlign pooling
  • box_head:TwoMLPHead
  • box_predictor:FastRCNNPredictor

MultiScaleRoIAlign

​ 该类与之前所述的RoIPooling不同,RoIAlign的定位能力更强。RoIPooling在计算过程中存在取整操作,从而引入了更多的定位误差,而Align不会进行取整操作。具体以后再展开~

image-20230325223450415

TwoMLPHead

​ TwoMLPHead其实就是RoIPooling之后跟着的两个全连接层(还有一个Flatten层).

image-20230325223606050

FastRCNNPredictor

​ FastRCNNPredictor也就是两个全连接层,分别预测每个proposal的类别和bbox的回归参数。

image-20230325223734155

注意输入的num_classes应该是实际类型+1,因为第0类是background

image-20230325223834839

ROIHeads

init

​ 保存一些需要用到的工具.

  • box_similarity负责计算box_iou
  • proposal_matcher负责正负样本的分配
  • fg_bg_sampler负责正负样本的采样
  • 其他参数就是刚刚提到的类以及一些阈值参数了

image-20230325223924768

forward

​ 训练模式下,首先会对proposal进一步采样,得到proposal样本和对应的label.

​ 其次将proposal和features特征层送入roi_pool得到每个proposal的box_features.box_features的形状应该是[num_proposals,channel,7,7]

​ 随后将box_features送入box_header提取出特征向量

​ 最后将这些向量送入box_predictor得到类别和回归参数预测结果

image-20230325224128474

​ 最后一部分代码如下.如果是训练模式下将通过fastrcnn_loss计算损失;如果是预测模式则会对proposals进行预处理postprocess_detections.最后返回相应的结果

image-20230325231206433

select_training_samples

该函数的功能是将RPN网络提供的Proposal进行采样,并计算这些Proposal的标签和regression参数(分配gtbox并计算,跟之前RPN网络内的操作类似)

​ 如下图所示,源码将gt_boxes也拼到了proposal后面,这里可能考虑到了PRN训练初始无法提供有效的proposal,所以加入gt_boxes来训练FastRCNN网络部分.

image-20230325224923241

​ 接下来将调用assign_targets_to_proposals函数将proposals分配给gt_boxes.这个函数在之前的RPN网络提到过,这里不再赘述.

​ 之后调用了subsample进行采样.得到一定比例的正负样本.

image-20230325225121300

​ 最后一步是遍历每张图片,首先找到正负样本(因为回归参数正负样本都参与计算)对应proposal的类别和proposal分配到的gt_box,再计算gt_box和proposal之间的回归参数(通过box_coder的encode方法,之前在RPN网络中有提到).

注意这里负样本对应的gt_box是第0个gt_box,按道理来说负样本不参与边界回归参数损失的计算.但是为了防止matched_idxs下标越界,所以在计算match_idxs时将-1都置为了0,导致现在"负样本有对应的gt_box,且计算了回归参数",不过这个问题不大,因为label记录了负样本的位置,在计算损失时忽略这部分即可~

image-20230325225957610

subsample

​ 该函数其实只是调用了fg_bg_sampler这个类对象,得到了每张图片里的正负样本索引,随后将每张图片的正负样本索引丢到sampled_inds列表里.

image-20230325225519409

fastrcnn_loss

​ 刚开始将label和regression cat起来是把不同图片的labels和回归参数都摞起来,一起处理.

​ 正负样本都会计算类别损失.

​ 而回归参数损失只计算正样本的,所以这里需要用sampled_pos_inds_subset记录正样本的位置.同时还需要对box_regression进行reshape处理,因为regression参数针对每个类别都会有四个参数. 最后使用smoothL1Loss进行正样本的回归参数损失计算.

image-20230325231409283

postprocess_detections

在预测模式下,将通过此函数得到最后的预测结果。具体流程见下图;具体操作见源码(带注释)

image-20230326095609394

Faster-RCNN
Theme Jasmine by Kent Liao