
ALBEF paper : http://arxiv.org/abs/2107.07651

BLIP paper : http://arxiv.org/abs/2201.12086

BLIP code : https://github.com/salesforce/BLIP

ALBEF

The network structure is as follows. ALBEF can be regarded as the predecessor of BLIP/BLIP2: its three losses have been carried forward ever since (though MLM later became LM). For an introduction to the three losses, see BLIP2 - JJJYmmm Blog. Notable points:

  • The ITC loss follows MoCo: a momentum encoder is used to enlarge the pool of negative samples, which is why there is an extra momentum model. The paper also improves the momentum model further from a model-distillation perspective, e.g. introducing pseudo-labels when computing ITC and MLM
  • For text understanding, ALBEF uses MLM rather than LM, probably because MLM is the simpler task: the model only needs to predict the masked tokens
  • The image-text similarities obtained while computing ITC can be used to pick out hard negatives specifically for ITM; this trick is reused in both BLIP and BLIP2
  • None of the three encoders share parameters here, so ALBEF counts as a typical two-stream model
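The MoCo-style ITC scoring can be sketched in a few lines of NumPy (a simplified illustration: `itc_logits` is my own name, features are assumed pre-normalized, and the real model additionally maintains an EMA momentum encoder and a FIFO queue update, omitted here):

```python
import numpy as np

def itc_logits(img_feat, txt_feat, txt_queue, temp=0.07):
    """MoCo-style ITC: score each image against the in-batch texts
    plus a queue of negatives produced by the momentum encoder."""
    all_txt = np.concatenate([txt_feat, txt_queue], axis=0)  # [B+Q, D]
    # features are assumed L2-normalized, so dot product = cosine similarity
    return img_feat @ all_txt.T / temp                        # [B, B+Q]

# toy check: 2 image/text pairs, queue of 3 momentum-encoded texts
rng = np.random.default_rng(0)
norm = lambda x: x / np.linalg.norm(x, axis=-1, keepdims=True)
img = norm(rng.normal(size=(2, 4)))
txt = norm(rng.normal(size=(2, 4)))
queue = norm(rng.normal(size=(3, 4)))
logits = itc_logits(img, txt, queue)
print(logits.shape)  # (2, 5): each image is scored against 2 + 3 texts
```

The queue lets the number of negatives grow far beyond the batch size without re-encoding old texts, which is the whole point of the momentum design.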

BLIP

[Figure: BLIP network architecture]

The network structure is shown above; note that parameters are shared across the different encoders. The losses are essentially the same as ALBEF's. Other points worth noting:

  • The unimodal encoders, for both text and image, extract features with self-attention; the Image-grounded Text Encoder mainly adds a cross-attention layer whose KVs are the image embeddings; the Image-grounded Text Decoder replaces the self-attention layer (really just switching the bidirectional mask to a causal mask~); the three text-side encoders/decoders all share the FFN. See the ablation study for the parameter-sharing details
  • One lingering question: why does the image-grounded design use the image embeddings as the cross-attention KVs? With text as the query, the cross-attention output is a weighting of image features by the text, so the resulting features should relate more to the image than to the text (one reason I can think of is that auto-regression forces the decoder's input to be text)
  • BLIP's main highlight is its dataset processing, a semi-supervised bootstrap; the figure below makes it clear~ The paper again argues for the bootstrap's effectiveness from a model-distillation angle, and the follow-up experiments show that after each round of cleaning/expanding the dataset, the model should be pre-trained from scratch, which matches the common wisdom in distillation that the student model should not inherit the teacher's parameters.

[Figure: BLIP dataset bootstrapping pipeline]
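In pseudo-Python, the bootstrap loop amounts to roughly the following (a schematic only: `bootstrap_dataset`, `captioner`, and `filter_fn` are hypothetical stand-ins for BLIP's caption generator and its ITM-based filter):

```python
def bootstrap_dataset(pairs, captioner, filter_fn, thresh=0.5):
    """CapFilt-style bootstrap (schematic): keep web captions the filter
    accepts, and add filtered synthetic captions from the captioner."""
    clean = []
    for img, web_txt in pairs:
        if filter_fn(img, web_txt) >= thresh:   # filter the noisy web caption
            clean.append((img, web_txt))
        syn_txt = captioner(img)                # generate a synthetic caption
        if filter_fn(img, syn_txt) >= thresh:   # filter that one too
            clean.append((img, syn_txt))
    return clean

# toy stand-ins with hard-coded match scores, for illustration only
scores = {("img1", "web1"): 0.9, ("img1", "cap of img1"): 0.8,
          ("img2", "web2"): 0.2, ("img2", "cap of img2"): 0.7}
clean = bootstrap_dataset([("img1", "web1"), ("img2", "web2")],
                          captioner=lambda im: f"cap of {im}",
                          filter_fn=lambda im, tx: scores[(im, tx)])
print(clean)  # img2's noisy web caption is dropped; both synthetic captions survive
```

Per the paper, after each such round the model is pre-trained from scratch on the cleaned set rather than fine-tuned from the previous checkpoint.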

BLIP2

Repository : https://github.com/salesforce/LAVIS/tree/main/projects/blip2

Pre-training Structure

Stage 1

The network structure is shown in the figure below.

  • For image features, BLIP2 follows a DETR-like idea: learned queries serve as the input and the image features act as the cross-attention KVs, so that learnable parameters extract the visual features most relevant to the text
  • For text features, it uses a conventional BERT-encoder approach
  • For multimodal fusion, unlike a two-stream model, the self-attention layers of the two encoders share parameters; it also differs from a single-stream model, since the FFNs are not shared and the visual branch additionally passes through a cross-attention layer

[Figure: BLIP2 stage-1 network architecture (Q-Former)]
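The learned-queries idea in the first bullet can be illustrated with bare single-head attention in NumPy (dimensions are illustrative; the real Q-Former uses projected multi-head attention inside a BERT layer, omitted here):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def query_image_cross_attn(queries, img_feats):
    """Learned queries attend to the frozen image features:
    Q = learned queries, K = V = image features."""
    d = queries.shape[-1]
    attn = softmax(queries @ img_feats.T / np.sqrt(d))  # [num_queries, num_patches]
    return attn @ img_feats                             # [num_queries, d]

rng = np.random.default_rng(0)
queries = rng.normal(size=(32, 64))     # 32 learned queries (BLIP2's default count)
img_feats = rng.normal(size=(257, 64))  # e.g. frozen ViT patch features
out = query_image_cross_attn(queries, img_feats)
print(out.shape)  # (32, 64): a fixed-length visual summary, independent of patch count
```

The output length is fixed by the number of queries, no matter how many image patches there are, which is what makes the queries a compact "bottleneck" for the visual information.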

The three training losses (ITG/ITM/ITC) mainly follow the earlier ALBEF work.

  • Image-Text Contrastive : aligns the unimodal image and text features, computed much like CLIP. Unlike ALBEF, the negative pairs come directly from within the batch (in-batch negatives); there is no MoCo-style large dictionary
  • Image-Text Matching : learns fine-grained alignment between image and text features, with the loss computed by an attached binary classifier. In this task the queries and the text can attend to each other
  • Image-Grounded Text Generation : trains the queries to capture the visual features that carry all the information about the text. Here the queries cannot see the text, while the text can see the queries through the self-attention layers; so to generate reasonably high-quality text, the queries can only extract as much text-relevant visual information as possible from the image features (an intuitive explanation)
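The hard-negative mining inherited from ALBEF can be sketched from the ITC similarity matrix. Note that ALBEF/BLIP actually sample negatives in proportion to similarity; the argmax below is a simplified deterministic variant:

```python
import numpy as np

def hard_negative_text_ids(sim):
    """For each image, pick the most similar *non-matching* text
    as an ITM hard negative (the diagonal holds the positive pairs)."""
    sim = sim.copy()
    np.fill_diagonal(sim, -np.inf)   # exclude the positive pair
    return sim.argmax(axis=1)

# toy 3x3 image-text similarity matrix (rows: images, cols: texts)
sim = np.array([[0.9, 0.6, 0.1],
                [0.2, 0.8, 0.7],
                [0.5, 0.3, 0.9]])
print(hard_negative_text_ids(sim))  # [1 2 0]: the most confusing wrong text per image
```

Training ITM on these near-miss pairs is far more informative than on random negatives, which the model can usually reject trivially.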

Stage 2

The stage-2 network structure is shown below: an FC layer aligns the Q-Former's dimension with the LLM's, and the Q-Former is fine-tuned further. The Q-Former's output mainly serves as a soft visual prompt for the LLM, conditioning its output.

[Figure: BLIP2 stage-2 network architecture]
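In code, stage 2 boils down to a linear projection plus sequence concatenation; a NumPy sketch with illustrative dimensions (768 for the Q-Former, 4096 for the LLM, both my own choices for the example):

```python
import numpy as np
rng = np.random.default_rng(0)

q_out = rng.normal(size=(32, 768))        # Q-Former output: 32 query embeddings
W = rng.normal(size=(768, 4096)) * 0.01   # the FC layer mapping to the LLM hidden size
visual_prompt = q_out @ W                 # [32, 4096] soft visual prompt
txt_emb = rng.normal(size=(10, 4096))     # embeddings of the text tokens
llm_input = np.concatenate([visual_prompt, txt_emb], axis=0)
print(llm_input.shape)  # (42, 4096): the prompt is prepended to the text sequence
```

The LLM itself stays frozen; only the Q-Former (and the FC) learn to produce prompt vectors the LLM can interpret.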

Source Code

BLIP2's model files live under lavis/models/blip2_models, which is treated as the root directory from here on.

Suggested reading order (main classes and functions):

  • ../base_model.py : BaseModel(nn.Module)
  • blip2.py : Blip2Base, compute_sim_matrix
  • Qformer.py : BertEmbeddings, BertLayer, BertEncoder, BertPooler, BertModel, BertOnlyMLMHead, BertLMHeadModel
  • blip2_qformer.py : Blip2Qformer

The main point of interest is the Qformer implementation, which is essentially a modified BERT:

  • The learned queries and the text enter the encoder together for attention; the interaction between the two is controlled by the self-attention mask, detailed in Figure 2 of the paper
  • The modified BERT adds a cross-attention layer after the self-attention layer of every other encoder layer; its KVs are the image encoder's output, i.e. the visual features
  • Only the query part activates the cross-attention computation, sliced out via input[:, :query_length, :]; likewise, if the input is text only (unimodal), no cross-attention is computed at all
  • past_key_values, i.e. the KV cache, is a trick to speed up auto-regressive decoder inference. Concretely, the KVs produced at each attention layer in the previous step (the embedding information of all tokens so far) are saved; in the next step only the new token (seq_len=1) is fed in, and its attention simply incorporates the saved KVs. This is valid because auto-regression has no after-effects: a token attends only to itself and earlier tokens
  • In stage-1 Qformer training, past_key_values is used exactly once: the unimodal stage encodes the image (query) features and the text features separately, saving each attention layer's KVs while encoding the query features; when computing the ITG loss, since both queries and text can see all queries, those saved KVs are reused and the model only needs to run the text part (with past_key_values, the queries need not, and must not, be prepended to the text input)
  • In principle the KVs produced while encoding the unimodal queries could also be reused for the ITM loss; the code doesn't do this because ITM additionally uses hard-negative mining, so the input has changed
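The past_key_values mechanics described above can be sketched with single-head attention in NumPy (`attend` is my own helper; real implementations cache KVs per layer and per head):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attend(q, k_cache, v_cache, k_new, v_new):
    """One decoding step with a KV cache: the new token's query attends
    to cached keys/values plus its own, so past tokens are never re-encoded."""
    k = np.concatenate([k_cache, k_new], axis=0)
    v = np.concatenate([v_cache, v_new], axis=0)
    w = softmax(q @ k.T / np.sqrt(q.shape[-1]))
    return w @ v, k, v   # return the output and the updated cache

rng = np.random.default_rng(0)
d = 8
k_cache = rng.normal(size=(5, d))   # KVs of 5 already-decoded tokens
v_cache = rng.normal(size=(5, d))
q_new = rng.normal(size=(1, d))     # only the new token is fed in (seq_len=1)
k_new = rng.normal(size=(1, d))
v_new = rng.normal(size=(1, d))
out, k_cache, v_cache = attend(q_new, k_cache, v_cache, k_new, v_new)
print(out.shape, k_cache.shape)  # (1, 8) (6, 8): the cache grows by one each step
```

This is exactly why the ITG pass can reuse the query-side KVs: the text tokens only need the queries' cached keys/values, not a fresh forward pass over the queries.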

A multi-task framework I wrote recently~
Project : https://github.com/JJJYmmm/Pix2SeqV2-Pytorch

Simple PyTorch implementation of Pix2SeqV2. This project references moein-shariatnia's Pix2Seq and the paper A Unified Sequence Interface for Vision Tasks.

[Figure: Pix2SeqV2 overview]

Introduction

Pix2Seq is a generalized framework for solving visual tasks proposed by Google. Essentially it treats visual tasks as language tasks: it generates token sequences auto-regressively and obtains the output of many visual tasks (e.g., object detection, segmentation, captioning, keypoint detection, etc.) by decoding the tokens.

The official implementation of Pix2Seq (google-research/pix2seq: Pix2Seq codebase) is written in TensorFlow, and I wished there were a PyTorch version. Then Shariatnia published a simple implementation of Pix2SeqV1 (object detection only, no multi-task training). I followed his project and added the following:

  • For object detection, add support for the COCO2017 dataset.
  • Keep the network structure unchanged and add interfaces for more tasks (instance segmentation, image captioning and keypoint detection).
  • Add support for multi-task training.

Some notes:

  • This project is just a simple PyTorch implementation; for the other tasks' interfaces I only consulted the original paper, so please refer to the official implementation for more details.
  • Since this is a practice project, I used a single RTX 3090 Ti GPU for training and inference. The main purpose is to verify the feasibility of multi-task training, so I have no further performance requirements.
  • If you want to improve the performance, try to 1) add more data augmentations, 2) train for more epochs, and 3) swap in a model with more parameters, etc.

If you have more questions about this project, feel free to open issues and PRs!

Environment

I use anaconda to manage my python environment. You can clone my environment by doing this:

# change to root dir
cd Pix2SeqV2
# create a new python 3.8 env
conda create -n your_env_name python=3.8
# install essential packages
pip install -r ./requirements.txt

If you want to run the project, you need at least one GPU with more than 12 GB of memory. Of course, the more GPUs the better!

I haven't written the code for multi-GPU training, but it's coming soon.

Configurations

All configurations can be modified in CFG class in Pix2SeqV2/config.py. Most of my training configurations come from Shariatnia's tutorials.

I use relative paths for other configs like weights and other required files. The only thing you need to change is the path of the dataset.

To fetch the VOC dataset, just cd download and bash download_voc.sh

To fetch the COCO2017 dataset, download it here

# For the VOC dataset, you need to change the following two variables
img_path = '../download/VOCdevkit/VOC2012/JPEGImages'
xml_path = '../download/VOCdevkit/VOC2012/Annotations'
# For the COCO dataset, you need to change dir_root
dir_root = '/mnt/MSCOCO'

I trained some weights for the different tasks; you can fetch them here. Put them in the folder Pix2SeqV2/weights so that you don't need to change the corresponding configs.

A Small Demo

Before diving into the formal Training & Inference section, let me show a small demo of the multi-task processing.

I trained the multi-task model weights for only 2 epochs (about 11 hours), covering four tasks (instance segmentation, object detection, image captioning and keypoint detection). So the results are unsurprisingly poor, forgive me =v=. The weights can be downloaded here.

I randomly chose a picture (No. 6471) from the COCO validation set for visualization.

[Figure: COCO validation image 000000006471]

Next, you can run the following code to get the results of the four tasks.

# make sure you're in the root directory and set the right weight path (multi_task_weight_path) in CFG
cd infer
python infer_single_image_multi_task.py --image ../images/baseball.jpg > result.txt

After that you will see three images (instance_segmentation.png, keypoint_detection.png, object_detection.png) and a text file (result.txt) in the infer directory.

result.txt shows all the predictions,

skipping pos_embed...
skipping pos_embed...
<All keys matched successfully>
Captioning:
[['baseball', 'player', 'swinging', 'his', 'bat', 'at', 'home', 'plate', '.']]
[['batter', 'during', 'the', 'game', 'of', 'baseball', 'game', '.']]
[['baseball', 'players', 'are', 'playing', 'baseball', 'in', 'a', 'field', '.']]
Bounding boxes:
[[ 15.665796 134.68234  130.5483   191.906   ]
 [262.4021    69.40819   90.07831  232.37599 ]
 [  0.        94.21238   15.665796  53.524773]
 [ 96.60574   78.54657   44.38643   61.357697]
 [206.26633  223.45518   28.720627  37.859024]
 [259.79114   72.01914   75.71802  229.76505 ]
 [ 97.911224 180.37425  137.07573  140.99214 ]
 [  0.        95.51784   19.582247  53.52481 ]]
Labels:
['person', 'person', 'person', 'person', 'baseball glove', 'person', 'person', 'person']
Keypoint list:
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 176, 125, 198, 117, 168, 145, 237, 157, 0, 0, 261, 167, 180, 184, 205, 183, 173, 208, 204, 210, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]

The three images visualize the results of the different visual tasks.

[Figure: object detection result]

[Figure: instance segmentation result]

[Figure: keypoint detection result]

The low recall of object detection task may be due to poor data augmentation and not enough training epochs.

The segmentation task performed OK given the object detection box, as it carried the largest training weight and I followed the original paper's setting of repeating the prediction eight times to ensure recall.

The keypoint detection task performed very poorly. I think there are a few reasons: first, it has the lowest weight in multi-task training; second, the bounding box I used in data augmentation seems to be too big (twice the size of the detection box, following the original paper's setup), resulting in more than one person in the bbox.

Anyway, JJJYmmm's Pix2SeqV2 has taken its first step !!!

Training & Inference

Object Detection

[Figure: object detection example]

For object detection, you can run the following commands to train Pix2Seq from scratch. Hyperparameters such as training epochs, learning rate, etc. can be set in ./config.py, and the weights are saved in the directory ./train.

# make sure you're in the root directory
cd train
python train_coco_object_detection.py # train on COCO2017
python train_voc_object_detection.py # train on VOC

Once the weights are obtained, you can run the following to infer on a single image.

# make sure you're in the root directory
cd infer
python infer_single_image_object_detection.py --image your_image_path # COCO2017
python infer_single_image_voc.py --image your_image_path # VOC

The predictions (bounding boxes and labels) are printed to the terminal and the visualization is saved in object_detection.png.

Training and inference for the other tasks do not differ much from this one.

Instance Segmentation

[Figure: instance segmentation example]

Code for training.

# make sure you're in the root directory
cd train
python train_coco_segmentation.py

Code for inference.

# make sure you're in the root directory
cd infer
python infer_single_image_segmentation.py --image your_image_path --box selected_area(format:xywh)

The results of visualization are saved in instance_segmentation.png.

Image Captioning

[Figure: image captioning example]

Code for training.

# make sure you're in the root directory
cd dataset
python build_captioning_vocab.py # generate vocab.pkl
# put the vocab.pkl to train folder or set the vocab_path in CFG
cd ../train
python train_coco_img_captioning.py

Code for inference.

# make sure you're in the root directory
cd infer
python infer_single_image_caption.py --image your_image_path

The results are printed in terminal.

Keypoint Detection

[Figure: keypoint detection example]

Code for training.

# make sure you're in the root directory
cd train
python train_coco_segmentation.py

Code for inference.

# make sure you're in the root directory
cd infer
python infer_single_image_segmentation.py --image your_image_path --box selected_area(format:xywh)

The results of visualization are saved in keypoint_detection.png.

Multi-Task

Code for training.

# make sure you're in the root directory
cd train
python train_multi_task.py --task task1,task2,task3...
# supported tasks: detection,keypoint,segmentation,captioning

Code for inference.

# make sure you're in the root directory
cd infer
python infer_single_image_multi_task.py --image your_image_path

The text results are printed to the terminal and the visualizations are saved in object_detection.png, keypoint_detection.png and instance_segmentation.png.

Some Results

[Figure: object detection result 1]

[Figure: object detection result 2]

[Figure: instance segmentation result]

[Figure: keypoint detection result]

Cite

  • Pix2seq : official implementation (TensorFlow)
  • Pix2seqV1 implementation (PyTorch)
  • Pix2seq paper:

    @article{chen2021pix2seq,
      title={Pix2seq: A language modeling framework for object detection},
      author={Chen, Ting and Saxena, Saurabh and Li, Lala and Fleet, David J and Hinton, Geoffrey},
      journal={arXiv preprint arXiv:2109.10852},
      year={2021}
    }
  • Pix2seq multi-task paper:

    @article{chen2022unified,
      title={A Unified Sequence Interface for Vision Tasks},
      author={Chen, Ting and Saxena, Saurabh and Li, Lala and Lin, Tsung-Yi and Fleet, David J. and Hinton, Geoffrey},
      journal={arXiv preprint arXiv:2206.07669},
      year={2022}
    }

Acknowledgement

Take That Plane to Lhasa (Civi fan edition)

Pix2Seq implemented in PyTorch

[Figure: Pix2Seq framework]

source code : moein-shariatnia/Pix2Seq

paper : http://arxiv.org/abs/2109.10852

This post dissects an unofficial implementation of Pix2Seq; the project link is above. For a tutorial, read the author's README or Object Detection w/ Transformers Pix2Seq in Pytorch | Towards AI, both of which are fairly detailed.

The model and training code are basically fine, though I recommend reading the original paper first, or you may get stuck on some details. The problems mainly appear in the test file, as follows.

Issue 1

2023.8.30

The max_len parameter of the Tokenizer class is used to cap the number of objects in a single image:

labels = labels.astype('int')[:self.max_len]

bboxes = self.quantize(bboxes)[:self.max_len]

But collate_fn in get_loaders uses max_len as the maximum length of the input token sequence. These two usages contradict each other, because one object corresponds to 5 tokens ([x1, y1, x2, y2, class]):

if max_len: # [B,max_seq_len,dim] -> [B,max_len,dim]
        pad = torch.ones(seq_batch.size(0), max_len -
                         seq_batch.size(1)).fill_(pad_idx).long()
        seq_batch = torch.cat([seq_batch, pad], dim=1)
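The contradiction is easy to see with a little arithmetic (max_len = 300 is an illustrative value, not necessarily the repo's default):

```python
# each object costs 5 tokens: [x1, y1, x2, y2, class]
tokens_per_object = 5
max_len = 300                     # illustrative value

# Tokenizer: max_len caps the *object count*, so up to max_len objects survive
max_objects = max_len
# collate_fn: the same max_len caps the *token length* of the padded sequence
tokens_needed = max_objects * tokens_per_object + 2   # + BOS + EOS
print(tokens_needed)  # 1502: far beyond the 300-token budget
assert tokens_needed > max_len
```

So the same value cannot simultaneously bound the object count and the sequence length; one of the two sites needs its own limit.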

Issue 2

2023.8.31

The postprocess function in test.py is buggy: it doesn't consider the case where the model detects no object, so an empty sequence (just \<BOS\> followed immediately by \<EOS\>) is fed into the tokenizer's decode method, raising an error.

 for i, EOS_idx in enumerate(EOS_idxs.tolist()):
        if EOS_idx == 0:
            all_bboxes.append(None)
            all_labels.append(None)
            all_confs.append(None)
            continue

The fix below handles the empty-sequence case:

 for i, EOS_idx in enumerate(EOS_idxs.tolist()):
        if EOS_idx == 0 or EOS_idx == 1:  # invalid idx when EOS_idx == 0; the model detected nothing when EOS_idx == 1
            all_bboxes.append(None)
            all_labels.append(None)
            all_confs.append(None)
            continue

Issue 3

2023.8.31

Line 125 of test.py makes a type-check mistake. The intent is to drop images with no detected objects, but at this point the x in the lambda is an array, so the isinstance argument should not be list; otherwise the filtered preds_df ends up empty.

# preds_df = preds_df[preds_df['bbox'].map(lambda x: isinstance(x, list))].reset_index(drop=True)
# fix : list -> np.ndarray
preds_df = preds_df[preds_df['bbox'].map(lambda x: isinstance(x, np.ndarray))].reset_index(drop=True)

Issue 4

2023.9.1 — This one really should affect the results, since the mapping between images and predictions is wrong. Together with Issue 3, it feels like the author never tested test.py again after finishing the tutorial and tidying the source.

Starting at line 117 of test.py, the predictions are saved. The intent is to map each prediction back to its image id, but the code simply truncates valid_df, which clearly truncates away some image ids (valid_df stores the objects of one image across several rows, so truncating by the total image count cuts off the ids of later images). The fix uses valid_df['id'].unique(), which matches line 133 exactly. This is valid only because the data loader uses shuffle=False.

# preds_df = pd.DataFrame()
# valid_df = valid_df.iloc[:len(all_bboxes)]
# preds_df['id'] = valid_df['id'].copy()
# preds_df['bbox'] = all_bboxes
# preds_df['label'] = all_labels
# preds_df['conf'] = all_confs

# I think there is a bug above: the code does not
# map predictions back to their image ids, it just
# truncates valid_df to len(all_bboxes)!!!

preds_df = pd.DataFrame()
preds_df['id'] = valid_df['id'].unique().copy()
preds_df['bbox'] = all_bboxes
preds_df['label'] = all_labels
preds_df['conf'] = all_confs

# line_133 : valid_df = df[df['id'].isin(preds_df['id'].unique())].reset_index(drop=True)

With the weights provided by the author, my result is indeed about 7% higher than the author's reported one (mAP = 0.264399). Moreover, because of the logic error above, the author's score was effectively computed on only about half of the boxes (nearly half of the test images were dropped); on the full data it would presumably be even worse.

[Figure: re-evaluated mAP result]

Issue 5

2023.9.1

TODO : report these issues to the author later; next, consider using DDP for convenient multi-GPU training.

Issues and PRs

https://github.com/moein-shariatnia/Pix2Seq/issues/6

https://github.com/moein-shariatnia/Pix2Seq/issues/7

https://github.com/moein-shariatnia/Pix2Seq/pull/8

Abstract

In the SIFT algorithm, one speed-up is the image pyramid, i.e. repeatedly downsampling the image. The algorithm's reasoning asserts that after downsampling, a Gaussian-blurred image with standard deviation $\sigma$ becomes a Gaussian-blurred image with standard deviation $\frac{1}{2}\sigma$.

I don't know how to prove this, and I couldn't find any material online, so for now I verify the claim numerically.
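For what it's worth, here is a heuristic (non-rigorous) argument for why the claim is plausible: subsampling a $\sigma$-blurred image evaluates the convolution on a grid with twice the spacing, and a Gaussian of width $\sigma$ in old pixel units has width $\sigma/2$ in the new units. Keeping only the even offsets $t = 2s$ (the odd ones contribute the aliasing error):

```latex
(I * G_\sigma)(2u) = \sum_t I(2u - t)\, G_\sigma(t)
\;\approx\; \sum_s I(2u - 2s)\, 2\, G_\sigma(2s)
= \big( I_{\downarrow 2} * G_{\sigma/2} \big)(u),
\qquad \text{since } G_\sigma(2s) = \tfrac{1}{2}\, G_{\sigma/2}(s).
```

The dropped odd-offset terms are exactly the small residual that the experiment below visualizes as an edge-like difference image.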

Experiment

The code is attached at the end. The main idea is to compare two images: one obtained by downsampling by 2 first and then Gaussian-blurring with $\sigma$; the other obtained by Gaussian-blurring with $2\sigma$ first and then downsampling the blurred image by 2.

First, visualize the two images and eyeball the difference; it is indeed quite small. Here $\sigma=30$ (ignore the values in the window titles, they are mislabeled~).

[Figure: the two images compared visually]

For comparison, I also visualized the original image blurred directly with $\sigma$ and with $2\sigma$. These two clearly differ, which shows that downsampling a Gaussian-blurred image really does change its effective $\sigma$.

[Figure: original image blurred with $\sigma$]

[Figure: original image blurred with $2\sigma$]

Next, subtract the first two images and plot the difference: the result looks like an edge-detection image. This shows that "downsample then blur with $\sigma$" and "blur with $2\sigma$ then downsample" are not exactly equivalent.

Why does it look like an edge map? Easy to see: subtracting two Gaussian-blurred images with different $\sigma$ amounts to differentiating the Gaussian-blurred image, i.e. a DoG; and DoG differs from LoG by only a constant factor, so it effectively performs edge detection~

[Figure: difference image, resembling an edge map]

Next, let's find out what the standard deviation of the "blur with $2\sigma$ then downsample" image actually is. The most brute-force way is search: sweep different values of $\sigma_0$ on the downsampled image and plot how the norm of the difference changes. The result: the closest $\sigma_0$ is exactly $\sigma$!

[Figure: difference norm as a function of $\sigma_0$]

So the two operations are not strictly equivalent, but they are close enough. SIFT uses this approximation to build the image pyramid and greatly improve efficiency.

Code

import cv2

Path = "C:\\Users\\Axuanz\\Desktop\\download.png"

if __name__ == "__main__":
    img = cv2.imread(Path,cv2.IMREAD_GRAYSCALE)

    sigma = 30

    ksize1 = sigma*6+1
    ksize2 = int(sigma/2*6+1)

    blur1 = cv2.GaussianBlur(img,ksize=(ksize1,ksize1),sigmaX=sigma)
    blur1 = cv2.pyrDown(blur1)

    downsample_img = cv2.pyrDown(img)
    blur2 = cv2.GaussianBlur(downsample_img,ksize=(ksize2,ksize2),sigmaX=sigma/2)

    blur3 = cv2.GaussianBlur(img,ksize=(ksize1,ksize1),sigmaX=sigma)
    blur4 = cv2.GaussianBlur(img,ksize=(ksize2,ksize2),sigmaX=sigma/2)

    delta = blur2 - blur1
    print(cv2.norm(delta))

    cv2.imshow("sigma=10",blur1)
    cv2.imshow("sigma=5",blur2)
    cv2.imshow("blur3",blur3)
    cv2.imshow("blur4",blur4)
    cv2.imshow("delta",delta)
    # cv2.waitKey()

###########################################
    sigma2 = 1
    norm_list = []
    while sigma2 <= sigma:
        _ksize = sigma2*6 + 1
        blur = cv2.GaussianBlur(downsample_img,ksize=(_ksize,_ksize),sigmaX = sigma2)
        delta = blur - blur1
        
        norm = cv2.norm(delta)
        print(f"sigma2 = {sigma2},delta norm ={norm}")

        norm_list.append(norm)
        sigma2 += 1
    
    import matplotlib.pyplot as plt
    x = list(range(1,sigma+1))
    print(x)
    print(norm_list)
    plt.plot(x,norm_list)
    plt.show()
###########################################