标签 源码阅读 下的文章

pix2seq implement by Pytorch

pix2seq - framework

source code : moein-shariatnia/Pix2Seq

paper : http://arxiv.org/abs/2109.10852

这次解析的源码是非官方实现的Pix2Seq,项目地址如上。教程直接看作者的Readme或者Object Detection w/ Transformers Pix2Seq in Pytorch | Towards AI,总体还是比较详细的。

模型即训练源码基本没什么问题,不过推荐先看完原始论文,不然可能在一些细节方面卡住。代码问题主要出现在测试文件中,问题如下。

Issue 1

2023.8.30

Tokenizer类的max_len参数用于限制单张图片的Obejct个数

labels = labels.astype('int')[:self.max_len]

bboxes = self.quantize(bboxes)[:self.max_len]

get_loaders中的collate_fn,把max_len作为输入序列的最大长度,这两处地方出现了矛盾(因为一个Object对应5个token,[x1, y1, x2, y2, class] )

if max_len: # [B,max_seq_len,dim] -> [B,max_len,dim]
        pad = torch.ones(seq_batch.size(0), max_len -
                         seq_batch.size(1)).fill_(pad_idx).long()
        seq_batch = torch.cat([seq_batch, pad], dim=1)

Issue 2

2023.8.31

test.py中的postprocess函数存在问题,没有考虑model未检出object的情况。导致会将一个空序列(即\<EOS\>\<BOS\>)输入tokenizer的decoder方法,从而引发错误

 for i, EOS_idx in enumerate(EOS_idxs.tolist()):
        if EOS_idx == 0:
            all_bboxes.append(None)
            all_labels.append(None)
            all_confs.append(None)
            continue

修正如下,考虑空序列的情况

 for i, EOS_idx in enumerate(EOS_idxs.tolist()):
        if EOS_idx == 0 or EOS_idx ==1: # invalid idx which EOS_idx = 0 or the model detect nothing when EOS_idx = 1 
            all_bboxes.append(None)
            all_labels.append(None)
            all_confs.append(None)
            continue

Issue 3

2023.8.31

test.py 中的第125行,发生类型判断错误。这里是剔除没有检测出物体的图片,但是此时lambda表达式中的x是array类型,所以isinstance函数参数不应该是list。否则过滤后preds_df会是空表

# preds_df = preds_df[preds_df['bbox'].map(lambda x: isinstance(x, list))].reset_index(drop=True)
# fix : list -> np.ndarray
preds_df = preds_df[preds_df['bbox'].map(lambda x: isinstance(x, np.ndarray))].reset_index(drop=True)

Issue 4

2023.9.1 按道理这个很影响结果啊,毕竟图片之间的映射都错了。加上Issue 3,感觉作者写完教程,整理源码后,并没有再测试test.py这个文件了。

test.py第117行开始,保存预测结果。这里的原意应该是把预测结果对应回每张图片的id,但是源码直接对valid_df做了截断,这里很明显导致了部分图片的id被trunc了(这是因为valid_df将同一图片中的不同物体分成了若干行,如果直接按照总的图片数目进行截断,后面的一些图片id就会被截掉)。使用valid_df['id'].unique()进行修改,正好与133行代码对应。能这么做也是因为data loader的shuffle=False

# preds_df = pd.DataFrame()
# valid_df = valid_df.iloc[:len(all_bboxes)]
# preds_df['id'] = valid_df['id'].copy()
# preds_df['bbox'] = all_bboxes
# preds_df['label'] = all_labels
# preds_df['conf'] = all_confs

# I think there is some bug above, because the code \
# do not consider the corresponding id of image, \
# it just trunc the valid_df with len(all_bboxes)!!!

preds_df = pd.DataFrame()
preds_df['id'] = valid_df['id'].unique().copy()
preds_df['bbox'] = all_bboxes
preds_df['label'] = all_labels
preds_df['conf'] = all_confs

# line_133 : valid_df = df[df['id'].isin(preds_df['id'].unique())].reset_index(drop=True)

使用作者提供的权重文件,跑出来的结果确实比原作者给出的结果(mAP=0.264399)高出7%左右,而且由于上述逻辑错误,作者的结果其实只在基本一半数量的box做了mAP的计算(因为测试图片就少了快一半),如果使用全量数据,他的效果应该会更差。

image-20230901001735855

Issue 5

2023.9.1

TODO : 之后把这些问题跟作者反馈一下,下一步考虑使用DDP,方便以后多卡训练。

Issues and PR

https://github.com/moein-shariatnia/Pix2Seq/issues/6

https://github.com/moein-shariatnia/Pix2Seq/issues/7

https://github.com/moein-shariatnia/Pix2Seq/pull/8

整体思路

​ GeneralizedRCNNTransform主要用在图像进入backbone网络前的预处理以及预测结果输出时的后处理两个阶段.主要工作是图像的标准化处理以及resize操作.

函数细节

__init__

__init__函数主要输入图像的均值和方差,以及resize时图片的最小(大)边长范围

image-20230322015154826

normalize

​ 最后一行通过添加None这个维度可以增加一维维度,再利用广播机制对image的每个像素都进行操作.

image-20230322015319593

resize

​ 这个方法首先调用_resize_image使用双线性插值调整图片大小,再通过resize_boxes调整对应的box大小.

image-20230322015549491

_resize_image

​ 根据宽高限制来确定缩放比例,调用interpolate对图像进行双线性插值,这里在image又添加一个维度,是因为interpolate方法输入需要是4D图像

image-20230322020119938

resize_boxes

​ 按照缩放比例调整box坐标即可.这里torch.stack()会在tensor最后新增一个维度,这里就是在最后一个维度摞起来

image-20230322020416714

batch_images

​ 这个方法是将一个batch图像中再次resize到统一尺寸,加速训练.这个统一尺寸被调整为size_divisible的整数倍

​ 具体实现时寻找一个batch中图片的高宽最大值,以此作为最大图像.其他图像跟该图像做左上角对齐,空余位置填充零.

​ 这种方法的好处是保证了图像的比例.

image-20230322020955182

forward

​ 对于每张图片,以此调用normalizeresize方法进行标准化和缩放.而在进行batch_images前,需要记录当前图像尺寸,存入image_sizes_list,最后与image打包成一个list,跟target标注一起返回.

​ 之所以要这么做是因为经过batch_images后,图像变成统一尺寸,但是图像有效区域在原先的图片大小范围内,所以需要保存batch_resize前的图像大小.

image-20230322021547272

postprocess

​ 这个方法是预测模式下最后的后处理操作.

image-20230322022153342

Faster RCNN框架图

image-20230322085023990

图源: deep-learning-for-image-processing/pytorch_object_detection/faster_rcnn at master · WZMIAOMIAO/deep-learning-for-image-processing (github.com)

源码主要内容

​ Faster R-CNN源码阅读将从以下几个方面展开,详见其他文档

  • DataSet
  • 网络框架
  • GeneralizedRCNNTransform
  • RPN
  • Predict Header
  • 正负样本划分与采样
  • Loss函数
  • PostProcess
  • Change Backbone(with FPN)

环境配置

  • Python 3.6/3.7/3.8
  • Pytorch>=1.6.0
  • pycocotools
  • Ubuntu or Centos
  • Use Gpu to train model
  • more details see requirements.txt

文件结构

  ├── backbone: 特征提取网络,可以根据自己的要求选择
  ├── network_files: Faster R-CNN网络(包括Fast R-CNN以及RPN等模块)
  ├── train_utils: 训练验证相关模块(包括cocotools)
  ├── my_dataset.py: 自定义dataset用于读取VOC数据集
  ├── train_mobilenet.py: 以MobileNetV2做为backbone进行训练
  ├── train_resnet50_fpn.py: 以resnet50+FPN做为backbone进行训练
  ├── train_multi_GPU.py: 针对使用多GPU的用户使用
  ├── predict.py: 简易的预测脚本,使用训练好的权重进行预测测试
  ├── validation.py: 利用训练好的权重验证/测试数据的COCO指标,并生成record_mAP.txt文件
  ├── coco.json: coco数据集标签文件
  └── pascal_voc_classes.json: pascal_voc标签文件

预训练权重

注意在源码中修改对应模型的路径与名称

数据集(以PASCAL VOC2012为例)

训练

  • 确保提前准备好数据集
  • 确保提前下载好对应预训练模型权重
  • 若要训练mobilenetv2+fasterrcnn,直接使用train_mobilenet.py训练脚本
  • 若要训练resnet50+fpn+fasterrcnn,直接使用train_resnet50_fpn.py训练脚本
  • 若要使用多GPU训练,使用python -m torch.distributed.launch --nproc_per_node=8 --use_env train_multi_GPU.py指令,nproc_per_node参数为使用GPU数量
  • 如果想指定使用哪些GPU设备可在指令前加上CUDA_VISIBLE_DEVICES=0,3(例如我只要使用设备中的第1块和第4块GPU设备)
  • CUDA_VISIBLE_DEVICES=0,3 python -m torch.distributed.launch --nproc_per_node=2 --use_env train_multi_GPU.py

注意事项

  • 在使用训练脚本时,注意要将--data-path(VOC_root)设置为自己存放VOCdevkit文件夹所在的根目录
  • 由于带有FPN结构的Faster RCNN很吃显存,如果GPU的显存不够(如果batch_size小于8的话)建议在create_model函数中使用默认的norm_layer, 即不传递norm_layer变量,默认去使用FrozenBatchNorm2d(即不会去更新参数的bn层),使用中发现效果也很好。
  • 训练过程中保存的results.txt是每个epoch在验证集上的COCO指标,前12个值是COCO指标,后面两个值是训练平均损失以及学习率
  • 在使用预测脚本时,要将train_weights设置为你自己生成的权重路径。
  • 使用validation文件时,注意确保你的验证集或者测试集中必须包含每个类别的目标,并且使用时只需要修改--num-classes--data-path--weights-path即可,其他代码尽量不要改动