标签 对比学习 下的文章

ALBEF paper : http://arxiv.org/abs/2107.07651

BLIP paper : http://arxiv.org/abs/2201.12086

BLIP code : https://github.com/salesforce/BLIP

ALBEF

网络结构如下,这篇算得上BLIP/BLIP2的前身了,其三个Loss一直延续至今(当然MLM变成了LM)。具体三个Loss的介绍可以看BLIP2 - JJJYmmm Blog,有以下几个特别点:

  • ITC Loss采用了MOCO的形式,即通过一个momentum encoder来扩大负样本的数量,这也是多了一个momentum model的原因。论文还从模型蒸馏的角度对momentum model做了进一步改进,例如在计算ITC和MLM时引入了伪标签
  • 这里对于文本理解任务,采用的是MLM而不是LM,可能是因为MLM任务相比于LM任务更简单,因为模型只需要预测被mask的单词即可
  • 在计算ITC时得到的Image-Text Similarity可以挑选出hard negatives,专门去做ITM;这个方法在BLIP和BLIP2都用到
  • 在这里三个Encoder都不共享参数,算是典型的双流模型

BLIP

image-20231104162902822

网络结构如上,可以看到不同Encoder之间共享参数。loss则与ALBEF没什么区别。其他值得注意的点有:

  • 对于单模态Encoder,无论是文本还是图片,都是采用self attention提取特征;对于Image-grounded Text Encoder,主要添加了一个cross attention层,KVs是Image Embedding;对于Image-grounded Text Decoder,替换了self attention层(其实就是mask改成casual mask吧~) ; 三个text相关的encoder/decoder都共用FFN;关于参数共享的细节,可以看消融实验
  • 其实一直都有一个问题,为什么Image-grounded要让Image Embeddings作为cross attention的KVs呢?这样text作为query,cross attn的结果不就是通过text加权得到的image features?这样得到的特征应该更多与image有关而不是text有关吧(我能想到的一个原因是auto-regressive限制了decoder的input必须是text)
  • BLIP的主要亮点是对数据集的处理,这里引入了半监督bootstrap的做法,具体看下面这张图就懂了~这里论文同样从模型蒸馏的角度来说明bootstrap的有效性,在后续的实验中也表明,每次对于清洗/扩充后的数据集,都应该从头对模型进行pre-train,这符合模型蒸馏中的学生模型不应该继承教师模型参数的常识。

image-20231104164756650

Repository:https://github.com/salesforce/LAVIS/tree/main/projects/blip2

预训练结构

第一阶段

网络结构如下图。

  • 对于图像特征,采用DETR类似的思路,使用Learned Queries作为输入,Image Features作为cross attention的KVs,希望通过可学习的参数来抽取与文本更相关的视觉特征
  • 对于文本特征,采用传统的Bert Encoder思路
  • 对于多模态特征的融合,与双流模型不同,这里两个Encoder的self attention层是共享参数的,当然与单流模型也不同,因为FFN不共享参数,且视觉特征提取时还会走cross attention层

image-20231104154128875

训练采用的三个Loss函数(ITG/ITM/ITC)主要参考之前的ALBEF工作。

  • Image-Text Contrastive : 这部分的目的主要是对齐图像特征和文本特征的单模态特征,计算方法类似CLIP。与ALBEF不同的是,这里的negative pairs直接采用in-batc方式得到,并没有像ALBEF那样借鉴MOCO得到一个较大的Dictionary
  • Image-Text Matching : 这部分的目的是学习图像特征与文本特征的细粒度对齐,通过外接二分类器计算Loss。这一阶段Query和Text可以互相关注
  • Image-Grounded Text Generation : 这一部分主要是训练Queries捕获有关文本所有信息的视觉特征,因为在这一步Query是无法看到Text信息的,而Text可以通过self attention层看到Query并输出结果,所以Query只能从Image Feature中尽可能提取与文本相关的视觉特征,才能生成一个质量比较高的Text(一个直观理解的explanation)

第二阶段

第二阶段的网络结构如下,通过一个FC对齐Qformer与LLM的维度,并对Qformer进一步微调。Qformer的输出主要作为LLM的一个soft visual prompts,提示LLM的输出。

image-20231104155834896

源码

BLIP2的model文件在lavis/models/blip2_models下,之后默认以此为根目录

阅读顺序(主要类与函数)如下:

  • ../base_model.py : BaseModel(nn.Module)
  • blip2.py : Blip2Basecompute_sim_matrix
  • Qformer.py : BertEmbeddings, BertLayer, BertEncoder, BertPooler, BertModel, BertOnlyMLMHead, BertLMHeadModel
  • blip2_qformer.py : Blip2Qformer,

关注点主要是Qformer的实现,其实就是一个魔改的Bert

  • Learned Queries和Text会一起进入Encoder做attention,两者之间的交互由self-attention层的mask控制,具体mask信息参照论文Figure 2
  • 魔改的Bert会每隔1个Encoder Layer()就在该层self-attention层后添加一个cross-attention层,其KVs是Image Encoder的输出即视觉特征
  • 只有Query部分激活cross-attention层进行计算,会通过input[:, :query_length, :]进行截取;同理,如果输入只有Text(Unimodal)则根本不会进行cross-attn计算
  • past_key_values或者说KV cache是在自回归decoder推理时加速的trick,具体来说就是decoder上一次运算各层attention的结果KVs(即当前所有token的embedding信息)会被保存,下次运算时,只需要输入新的token(seq_len=1),进行attention计算时加入之前保存的KVs即可。这样做的可行性主要来自自回归模式的无后效性,token做attention时只会和自己以及之前的token交互
  • 在Qformer第一阶段训练中,past_key_values只使用一次:单模态阶段分别encode图像(Query)特征和文本特征,在encode图像(Query)特征时保存各个attn层的KVs;在计算ITG损失时,由于Query和Text都可以看到所有Query,于是这里就使用了当时单模态计算Query特征时的KVs,模型就只需要跑Text部分即可(使用past_key_values后,Query不需要也不能拼接在Text前面作为输入)
  • 一般来说计算ITM损失时也可以利用计算单模态Query时的产生的KVs,代码里之所以没这么做,是因为计算ITM时额外使用了hard negatives mining,输入已经发生变化

找到一个Github star比较多的CLIP跑起来玩一下。结果发现挺多坑的......(我发现Shariatnia好像很喜欢在Jupyter改代码,却不把改动更新到源文件)

项目地址:https://github.com/moein-shariatnia/OpenAI-CLIP

关于Loss的一些问题

CLIP网络的输入是一个个图像-文本对,对于一个batch_size=8的输入,总共有8个正样本对和56个负样本对。于是我们可以使用torch.eye来创建标签值。在这个项目中,由于Flicker-8k数据集较小,而且一张图片对应5个caption,所以作者使用文本(图像)之间的相似度作为标签,如下代码块。

targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature,                 
dim=-1)
texts_loss = cross_entropy(logits, targets, reduction='none')
images_loss = cross_entropy(logits.T, targets.T, reduction='none')

这引发了我几个问题:

  • images_losstexts_loss好像写反了?(虽然这并不影响模型的训练,因为最终的loss是两者的平均) 在原论文中,images_loss指的是针对某个文本,images的分类损失
  • 使用图像(文本)的embeddings之间的相似度来作为标签值,这确实可以处理一个batch中出现重复图片(文本)的问题,并且带有Pseudo-Labelling的味道。但是这样处理同样也会出现问题,1)如果不使用预训练模型,那么训练初期得到的图像(文本)相似度没有意义,无法作为标签值,即一定需要一个好的预训练模型;2)训练后期,出现过拟合现象,两个encoder或投影层会倾向于输出相差不大的向量,从而得到一个较低的loss,这类似于GAN中的模式坍塌(实际上, 如果将模型的学习率统一改成较大的一个值例如lr=1e-3,那么一个epoch之后,你就会发现无论输入什么文本,模型给出的图像检索结果都是固定的,如果你问我为什么发现,其实作者的代码一开始就是这么写的555)

    model = CLIPModel().to(CFG.device)
    optimizer = torch.optim.AdamW(
            model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", patience=CFG.patience, factor=CFG.factor)
  • 我同样好奇(images_similarity + texts_similarity) / 2 这一步是否正确。对于images_loss,我们可以通过images_similarity找到重复(类似)的图像,从而给他们相同的标签值。但是texts_similarity是否可以指导images_loss? 我认为不能,因为文字具有多义性,两段文本类似,不代表对应的图片就是类似的。所以我尝试修改了loss函数,使它更符合直觉,如下。

    images_similarity = F.softmax(images_similarity, dim=-1)
    texts_similarity = F.softmax(texts_similarity, dim=-1)
    images_loss = cross_entropy(logits, images_similarity, reduction='none')
    texts_loss = cross_entropy(logits.T, texts_similarity, reduction='none')

实验结果

为了验证我的猜想,我做了一些实验。

①模式坍塌

正如上文所说,如果两个Encoder学习率过大,那么会发生模式坍塌现象。这里就不再赘述,大家可以自己试一下~

②Loss Function && Temperature

首先我修改了一下Temperature这个超参(原论文里这是一个可学习参数,数值最后稳定在100左右),motivation在于:我发现使用similarity作为label值,正负样本对之间的差距有点小(大概是2.5%),所以我想把T调小一点,增大他们之间的差距。当然,结果好像不是很好(T=0.1,可能给小了)

clip_author_loss_1epochs_bs32_t10

第二个实验是关于(images_similarity + texts_similarity) / 2这一步similarity平均的效果,我使用了上文提到的分别计算的方法进行对比,结果发现:没啥区别... valid_loss也几乎没差

clip_author_mine_loss_4epochs_bs32

image-20231024182147993

我猜测可能是因为batch_size不够大,即图片(文本)重复的概率不够大。所以我把batch_size从32调到了128,结果发现:好像平均可以加快收敛,但是为什么可以,还是有点不够intuitive...(也有可能单纯没差~)

clip_author_mine_loss_4epochs_bs128

image-20231024182233396

③Prompt Engineering

Prompt的好坏可以通过两幅图表示,我使用训练了5个epoch的模型权重。

当我们使用so many dogs作为query时,可以发现检索出来的图片除了狗以外,还有一些成群的鸭子或鸽子,这可能是因为text encoder同时注意到了many

many_dogs

当我们使用a photo of dogs作为query时,可以发现检索出来的图片就都是狗了。

photo_of_dogs

当然还可以像CLIP官方那样使用80个prompt做prompt ensembling。

总结

CLIP是一篇利用对比学习完成多模态特征对齐的工作,它的ZERO-SHOT能力非常亮眼(毕竟在400million的数据集上训练的...),同时也因为多模态的特点,可以引出很多有趣的应用,比如图像检索、编辑、生成等等。