分类 深度学习 下的文章

http://arxiv.org/abs/2305.14314

Quantization + LoRA,了解一下呢。

LoRA之前写过博客了,这里主要想说一些如何与量化结合起来,以及用到的一些新的tech

当然像分页优化这种就掠过了...需要用到Nvidia统一内存去做内存/GPU交换,cuda编程不会捏,感兴趣可以看论文附的链接https://docs.nvidia.com/cuda/cuda-c-programming-guide/

4-bit NormalFloat(NF4)

之前谈量化说到INT8/INT4,是把浮点数量化成定点数。量化操作一般是除去一个量化因子$s = \frac {2^{N-1}} {ABSMAX}$,这个其实潜在地假设了数据满足的是均匀分布,实际上LLM权重满足的是一个正态分布。

那么NF4就把概率作为量化间隔,而不是把实际的数值差作为量化间隔。看以下这个图就明白了。p满足等差数列,而X明显不满足,越近0,X的分布越密集。

图源 https://zhuanlan.zhihu.com/p/654967425

v2-48f1fa9168fe79028365cba76d3bfdc4_720w

这个时候把X归一化到到[-1,1],得到的就是quant book啦。那么实际应用中怎么量化呢,其实和传统量化差不多,输入数据首先除一个ABSMAX,归一化到[-1,1],然后查quant book,最靠近的那个浮点数就是对应的量化等级了。(反量化的时候直接取quant book对应的那个浮点数即可)

NF4的quant book是['-1.0000', '-0.6962', '-0.5251', '-0.3949', '-0.2844', '-0.1848', '-0.0911', '0.0000', '0.0796', '0.1609', '0.2461', '0.3379', '0.4407', '0.5626', '0.7230', '1.0000']

这种量化方案考虑了输入本身的分布(权重近似正态分布),bin划分的更合理,自然会减少一定的量化误差啦。

Double Quantization

双重量化,顾名思义就是执行两次量化。

回顾传统量化,对于一个权重矩阵W,为了更好的表示它,一般会对它进行per-channel的量化,即把W分成n块,每一块对应不同的量化因子s,量化后,我们不仅要保存量化后的权重W‘,还要保存这n个量化因子(因为反量化的时候要用),而且量化因子一般和W的类型一样,例如同样是FP32。

那么为了保证量化质量,这个n可能很大,比如和W的channel数相当,那么量化因子的保存就会是一笔比较大的内存开销。所以!Double Quantization就是量化量化因子,让它占的内存少一点!

开始小小的实例分析(其实是论文里附带的例子)。

对于FP32的W,我们决定把它量化成INT4,块大小取64(总共$numel(W)/64$块),那么相当于W中的每64个参数会共享一个FP32量化因子s,那么每个参数对应的量化因子bit数是$32/64=0.5$。

如果我们对量化后的因子再次量化,将其量化到8bit,块大小选择256(这里指每256个量化因子共用一个FP32量化因子,有点绕哈哈哈哈),那么这个时候我们得到了一个新的量化因子s'(FP32),和一个量化后的INT8量化因子q(s),这个时候再来计算下W每个参数对应的量化因子bit数,首先第一次量化的块大小是64,即W中64个参数共用一个量化因子s,s又被量化到了8bit,所以是$8/64$,我们还要保存量化s的量化因子s’,这个s‘被256个s共用,每个s又被W中64个参数共用,那么一个s'被64*256个W参数共用,那么就是$32/(64 \times 256)$,最后每个参数对应的量化因子bit数是$8/64 + 32/(64 \times 256) = 0.127$,比传统量化减少了0.373bit/per param

QLoRA

介绍完前面两个,QLoRA的训练用以下公式就一目了然了。forward过程W会反两次量化从NF4到BF16,和同为BF16的X做运算;backward的时候只对$L_1, L_2$进行更新(即LoRA里的A,B)

image-20240321200705889

doubleDequant就是我刚刚提到的两次反量化啦。先反量化W的量化因子s,再用恢复后的量化因子反量化W。

image-20240321200904180

挺有意思的一篇论文 Stealing Part of a Production Language Model(arxiv2403.06634),给出了一种通过black-box API查询来恢复LLM投影矩阵的方法。实际意义嘛...可以拿来蒸馏?但是成本还挺高的(

同期工作 Logits of API-Protected LLMs Leak Proprietary Information arxiv/2403.09539

简单讲一下Algorithm 1(其他的看不懂

image-20240318151657594

这里假设攻击者的能力是可以拿到LLM的Logit-Vector API,也就意味着可以拿到每个token对应的logits,注意这里还没有过softmax,所以算法里没有考虑softmax的影响。这个假设其实很强,之后会慢慢放开到目前商用的API。

假设LLM的hidden-dim是$h$,词表大小是$l$。那么我们可以假设一个$n$,也就是查询API的次数,我们希望它比$h$要大。每次查询,我们输入LLM随机的前缀作为Prompt,那么我们会拿到一个长为$l$的logit-vector。

虽然论文里没讲,但这里的logits应该是指prompt最后一个token对应的logits

虽然logits长度为$l$,但是它们应该都在$dim = h$的子空间中,因为logits是通过一个hidden_states($dim = h$)乘上一个投影矩阵($h$x$l$)得到的。

因此,如果查询次数足够多,那么之后得到的logits响应会和之前的线性有关,这就给了我们分析的空间。

回到Algo. 1中的Q,因为Q的每行都是一个logits响应,它们处在维度为h的子空间,而子空间最多只有h个向量彼此线性无关,所以有$Rank(Q) \le h$。当n足够大时,$Rank(Q) = h$。

这里我省掉了论文里的$Q = WH$,因为作者应该是默认W一般是满秩的,不影响秩的计算。

那么现在求解h的方法就转向求解$Rank(Q)$,一般来说求一下奇异值个数count就可以了。

不过Q由logits组成,即浮点数矩阵,所以求解奇异值的时候会出现一些数值很小的假奇异值。于是作者这里在求解奇异值之后,sort了一下并通过$log ||\Delta||$筛选出真正的奇异值的个数。

以下是对Pythia 1.4B的hack,可以看到查询足够多的情况可以观察到一个明显的gap,从而确定hidden dimension是2048。

n < h时,$Q$是满秩的,且有n个非平凡奇异值,所以无法恢复出h

image-20240318155154037

为什么开始看量化了,我也不知道

量化原理

我们一般谈LLM的精度,会涉及到FP32,FP16,BF16,INT8,INT4等字样。这些字段确定了LLM中一个参数所占的内存空间(如FP32指4字节浮点数,FP16和BF16指2字节浮点数,其中BF具有更多的指数位,INT8/4分别占8/4个比特)。其中INT8/4就涉及了模型量化,一个浮点数如何量化成一个定点数,这是量化干的最基本的事儿。

img

最直观的想法就是一组浮点数除去他们的ABS MAX,这些参数就落在了[-1,1]之间,然后根据量化位数N的不同乘以对应的步长$2^{N-1}$,接下来得到的数就是量化结果啦。

但上面这个问题在于浮点数的分布不均时,量化空间会有所浪费,所以实际应用时会进行一定的截断。

img

量化结束后,这组浮点数(weight/activation)就以定点数的形式存了下来。那么如何利用量化后的权重/激活呢?我们知道量化的时候一个浮点数先除了一个ABS MAX,再乘了一个$2^{N-1}$,那么$2^{N-1}/MAX$就是一个$\Delta$,我们实际做运算的时候额外把这个数除一下就可以了,比如$WX = \frac {\Delta W} {\Delta} X= \frac {Q(W)} {\Delta} X$。

解决完量化/反量化的问题,下一个问题是如何给浮点数分组,全部用一个$\Delta$肯定不是好方法,SmoothQuant总结了一个很好的图,红色框线的区域代表用了同一个$\Delta$。per-tensor很好理解,很适合并行化,但是量化效果可能不好,比如X的不同channel具有不同的值域。per-tokenper-channel就比较合适,在一个token/channel内进行量化。

但是为什么X应用per-token,W应用per-channel呢?不能反过来?这其实是处于性能考虑。实际计算的时候,假如X,W都进行量化,那么矩阵乘法计算如下,可以看到反量化过程会涉及X,W的两个scale因子

image-20240317225243979

如果想进一步优化速度,我们当然是希望提出两个scale,这样$x_q$和$w_q$两个量化后的结果就可以愉快地做乘法,对于整个XW也意味着可以先做INT8的矩阵乘法,而不是FP32/16的MM,两个scale只要在INT8矩阵乘法后做一个element-wise的乘法即可。说了这么多好处,提出两个scale的条件是与k无关,那么也就意味着,不同k之间的scale是一样的,也就对应per-tokenper-channel了~

image-20240317225528660

这部分内容来自 arxiv2004.09602

image-20240317224648570

说到这里,其实量化的理论基础就讲完了(吗

感觉量化还是更偏底层系统的方向,实现难难的

SmoothQuant

SmoothQuant的motivation是他们发现LLMs某些层产生的activation(上图中的X)在某些通道(col)上非常大,这导致了量化难度的增加,因为目前为了硬件实现效率,大家对于X的量化基本都是per-token,无法处理某些channel的outlier。而对于权重W,一般相对activation来说更加平坦,更容易量化。

在这个背景下,SmoothQuant提出可以把量化难度从X转移到W。

image-20240318125130339

具体来说,就是先对X进行per-channel的scale,然后把scale因子乘回W,保证等价性。

值得注意的是,这里还不涉及量化阶段,只是前处理阶段,或者论文中提到的offline stage。

有些人可能好奇X/输入从哪里来,X是来自校准集的元素产生的activation,是用来确定s的。

image-20240318125301476

那么S应该怎么确定呢,如果要尽可能的处理X中的outlier,最好的方法就是每个通道除以该通道最大的ABS MAX。但是这么处理会导致W的outlier过大,W量化难度过大。所以综合考虑X/W的量化难度,s由下式决定,$\alpha$一般取0.5。

image-20240318125858397

以下是一个Smooth例子,右侧是处理过的X/W。这个时候他们都比较好量化了。(后续实验好像是对这俩做了per-tensor量化)

image-20240318130042826

再次重申,这个时候还没有到量化阶段,这只是对activation/weight进行smooth。接下来就可以拿着Smooth后的W去量化咯,这里直接采用常规量化就可以了。

可能有人好奇,推理的时候,activation每进行一个MM,都需要scale一下吗,其实这个$diag(s)^{-1}$可以融到之前一个操作的矩阵里,这样就避免了额外的延时。

AWQ: Activation-aware Weight Quantization

AWQ应该是目前最流行的INT4量化方法了,他是SmoothQuant的延申工作。

我们之前说到,LLM中有些activation的某些通道值特别大,导致其不好量化。而在仅权重量化的基础上,作者发现,如果保留这些通道对应的W的精度(这些salient weight大概占整个权重的0.1%),比如保持FP16,那么结果的ppl同样可以保持很好的性能。

image-20240318131321256

但是混合精度表示W对于并行是较差的,所以作者进一步提出Scale的操作,希望把这些salient weight也量化了。

image-20240318131524548

这一部分呢,我觉得跟SmoothQuant还是很像的,不过这里对于X的每个通道是计算平均值找到对应的salient weight。找到这些深蓝色的salient weight之后,作者对他们进行了一定的放缩,增大了W(对应的也要缩小X对应的channel,这个和SmoothQuant一样,当然看问题的角度不太一样,因为这里只考虑仅权重量化,并不量化activation)。

那么为什么做这个Scale操作呢?其实是为了减少量化损失,对于普通的权重量化,损失一般在于Round操作的舍入误差,一般浮点数的舍入值在0~0.5,平均误差就是0.25。

image-20240318132142871

而先scale再量化的公式如下,一般来说在对应的salient weight row乘上因子s并不会影响weight的极值,那么$\Delta \approx \Delta'$,而Round误差一般也是不变的,那么下式的Err相比于原先的Err会多出一个$1/s(s \gt 1 )$,那么量化误差就变低。

这里的$\Delta$应该也是$\Delta'$,感觉是论文打错了,因为$\Delta$指$ws$即新权重的量化因子

image-20240318132356458

综上所述,对于salient weight对应的scale操作和SmoothQuant如出一辙,只不过一个是为了减少量化误差,一个是为了降低activation的量化难度。

DDPM中建模的$q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)$满足正态分布,

$$ q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\boldsymbol{\mu}}(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta}_t \mathbf{I}) \\ \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t $$

DDIM中建模的$q_\sigma(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)$如下,第一个等式二三步用到了重参数技巧多个独立高斯分布的等价形式

$$ \begin{aligned} \mathbf{x}_{t-1} &= \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1}}\boldsymbol{\epsilon}_{t-1} \\ &= \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \boldsymbol{\epsilon}_t + \sigma_t\boldsymbol{\epsilon} \\ &= \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \frac{\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\mathbf{x}_0}{\sqrt{1 - \bar{\alpha}_t}} + \sigma_t\boldsymbol{\epsilon} \\ q_\sigma(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) &= \mathcal{N}(\mathbf{x}_{t-1}; \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \frac{\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\mathbf{x}_0}{\sqrt{1 - \bar{\alpha}_t}}, \sigma_t^2 \mathbf{I}) \end{aligned} $$

DDIM中的的$\sigma_t$与DDPM中的$\tilde{\beta_t}$保持一致,并且添加了一个可学习参数控制方差,

$$ \sigma^2_t = \eta \cdot \tilde{\beta}_t $$

当$\eta=0$时,采样过程是确定的;当$\eta=1$时,退化成DDPM的形式,以下给出推导,

$$ \begin{aligned} \mu_{\mathbf{x}_{t-1}} &= \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \frac{\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\mathbf{x}_0}{\sqrt{1 - \bar{\alpha}_t}} \\ &= \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \tilde{\beta}_t} \frac{\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\mathbf{x}_0}{\sqrt{1 - \bar{\alpha}_t}} \\ &= \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \frac{(1 - \bar{\alpha}_{t-1}) \cdot (1 - \alpha_t)}{1 - \bar{\alpha}_t}} \frac{\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\mathbf{x}_0}{\sqrt{1 - \bar{\alpha}_t}} \\ &= (\sqrt{\bar{\alpha}_{t-1}} + \sqrt{1 - \bar{\alpha}_{t-1} - \frac{(1 - \bar{\alpha}_{t-1}) \cdot (1 - \alpha_t)}{1 - \bar{\alpha}_t}} \cdot \frac{ - \sqrt{\bar{\alpha}_t}}{\sqrt{1 - \bar{\alpha}_t}}) \cdot \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \frac{(1 - \bar{\alpha}_{t-1}) \cdot (1 - \alpha_t)}{1 - \bar{\alpha}_t}} \cdot \frac{\mathbf{x}_t}{\sqrt{1 - \bar{\alpha}_t}} \\ &= (\sqrt{\bar{\alpha}_{t-1}} + \sqrt{(1 - \bar{\alpha}_{t-1}) \cdot (1 - \frac{1 - \alpha_t}{1 - \bar{\alpha}_t})} \cdot \frac{ - \sqrt{\bar{\alpha}_t}}{\sqrt{1 - \bar{\alpha}_t}}) \cdot \mathbf{x}_0 + \sqrt{(1 - \bar{\alpha}_{t-1}) \cdot (1 - \frac{1 - \alpha_t}{1 - \bar{\alpha}_t})} \cdot \frac{\mathbf{x}_t}{\sqrt{1 - \bar{\alpha}_t}} \\ &= (\sqrt{\bar{\alpha}_{t-1}} + \sqrt{(1 - \bar{\alpha}_{t-1}) \cdot \frac{\alpha_t - \bar\alpha_t}{1 - \bar{\alpha}_t}} \cdot \frac{ - \sqrt{\bar{\alpha}_t}}{\sqrt{1 - \bar{\alpha}_t}}) \cdot \mathbf{x}_0 + \sqrt{(1 - \bar{\alpha}_{t-1}) \frac{\alpha_t - \bar\alpha_t}{1 - \bar{\alpha}_t}} \cdot \frac{\mathbf{x}_t}{\sqrt{1 - \bar{\alpha}_t}} \\ &= (\sqrt{\bar{\alpha}_{t-1}} + \sqrt{1 - \bar{\alpha}_{t-1}} \cdot \frac{\sqrt{\alpha_t - \bar\alpha_t}}{\sqrt{{1 - \bar{\alpha}_t}}} \cdot \frac{ - \sqrt{\bar{\alpha}_t}}{\sqrt{1 - \bar{\alpha}_t}}) \cdot \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1}} \cdot \frac{\sqrt{\alpha_t - \bar\alpha_t}}{\sqrt{{1 - \bar{\alpha}_t}}} \cdot \frac{\mathbf{x}_t}{\sqrt{1 - \bar{\alpha}_t}} \\ &= (\sqrt{\bar{\alpha}_{t-1}} - \sqrt{1 - \bar{\alpha}_{t-1}} \cdot \frac{\sqrt{1 - \bar\alpha_{t-1}} \cdot \sqrt{\alpha_t}}{\sqrt{{1 - \bar{\alpha}_t}}} \cdot \frac{\sqrt{\bar{\alpha}_t}}{\sqrt{1 - \bar{\alpha}_t}}) \cdot \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1}} \cdot \frac{\sqrt{1 - \bar\alpha_{t-1}} \cdot \sqrt{\alpha_t}}{\sqrt{{1 - \bar{\alpha}_t}}} \cdot \frac{\mathbf{x}_t}{\sqrt{1 - \bar{\alpha}_t}} \\ &= (\sqrt{\bar{\alpha}_{t-1}} - \frac{\sqrt{\alpha_t} \cdot \sqrt{\bar\alpha_t} \cdot (1 - \bar\alpha_{t-1})}{1 - \bar\alpha_{t}}) \cdot \mathbf{x}_0 + \frac{\sqrt{\alpha_t} \cdot (1 - \bar\alpha_{t-1}) \cdot \mathbf{x}_t}{1 - \bar\alpha_{t}} \\ &= (\frac{\sqrt{\bar{\alpha}_{t-1}} - \bar\alpha_t \cdot \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\alpha_t} \cdot \sqrt{\bar\alpha_t} \cdot + \sqrt{\alpha_t} \cdot \sqrt{\bar\alpha_t} \cdot \bar\alpha_{t-1}}{1 - \bar\alpha_{t}}) \cdot \mathbf{x}_0 + \frac{\sqrt{\alpha_t} \cdot (1 - \bar\alpha_{t-1}) \cdot \mathbf{x}_t}{1 - \bar\alpha_{t}} \\ &= (\frac{\sqrt{\bar{\alpha}_{t-1}} - \bar\alpha_t \cdot \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\alpha_{t-1}} \cdot \bar\alpha_t + \sqrt{\bar\alpha_{t-1}} \cdot \bar\alpha_t}{1 - \bar\alpha_{t}}) \cdot \mathbf{x}_0 + \frac{\sqrt{\alpha_t} \cdot (1 - \bar\alpha_{t-1}) \cdot \mathbf{x}_t}{1 - \bar\alpha_{t}} \\ &= (\frac{\sqrt{\bar{\alpha}_{t-1}} - \sqrt{\alpha_{t-1}} \cdot \bar\alpha_t }{1 - \bar\alpha_{t}}) \cdot \mathbf{x}_0 + \frac{\sqrt{\alpha_t} \cdot (1 - \bar\alpha_{t-1}) \cdot \mathbf{x}_t}{1 - \bar\alpha_{t}} \\ &= \mu_{\mathbf{x}_{t-1}}^{DDPM} \end{aligned} $$

最近看到一个写的挺好的多任务框架(https://github.com/SwinTransformer/AiT,参考了detectron2),分享一下~

多任务训练一般分为两种:data mixing 和 batch mixing。简单来说,对于前者,一个batch中的样本可以来自不同任务,而后者一个batch中任务都是一样的。两者相比,后者实现更加容易,效率更高,并且做数据增强也更方便一点(由Pix2Seq提出,但其实数据增强我认为并不是两者的主要差异)。

先给出我写的一个batch mixing的例子(来自JJJYmmm/Pix2SeqV2-Pytorch):

def get_multi_task_loaders(tokenizer,tasks):

    assert set(tasks) <= set(['detection', 'keypoint', 'segmentation', 'captioning'])

    train_loaders = {}
    valid_loaders = {}

    if 'detection' in tasks:
        detection_train_loader, detection_valid_loader = detection_loaders(
        CFG.dir_root, tokenizer, CFG.img_size, CFG.batch_size, CFG.max_len, tokenizer.PAD_code)
        train_loaders['detection'] = detection_train_loader
        valid_loaders['detection'] = detection_valid_loader
   
    if 'keypoint' in tasks: 
        keypoint_train_loader, keypoint_valid_loader = keypoint_loaders(
        CFG.dir_root, tokenizer,person_kps_info, CFG.img_size, CFG.batch_size, CFG.max_len, tokenizer.PAD_code)
        train_loaders['keypoint'] = keypoint_train_loader
        valid_loaders['keypoint'] = keypoint_valid_loader

    if 'segmentation' in tasks: 
        segmentation_train_loader, segmentation_valid_loader = segmentation_loaders(
        CFG.dir_root, tokenizer,person_kps_info, CFG.img_size, CFG.batch_size, CFG.max_len, tokenizer.PAD_code)
        train_loaders['segmentation'] = segmentation_train_loader
        valid_loaders['segmentation'] = segmentation_valid_loader

    if 'captioning' in tasks:
        img_caption_train_loader, img_caption_valid_loader = img_caption_loaders(
        CFG.dir_root, tokenizer,vocab, CFG.img_size, CFG.batch_size, CFG.max_len, tokenizer.PAD_code)
        train_loaders['captioning'] = img_caption_train_loader
        valid_loaders['captioning'] = img_caption_valid_loader
    
    return train_loaders, valid_loaders

以上代码首先维护一个多任务的dataloader字典,在这里之前(即创建dataset)就可以针对不同任务做对应的数据增强。

# get longest dataloader
epoch_size = 0
longest_loader = None
for name, loader in train_loaders.items():
    if len(loader) > epoch_size:
        epoch_size = len(loader)
        longest_loader = name

# create iter for dataloaders
loader_iters = dict()
for k, v in train_loaders.items():
    if k != longest_loader:
        loader_iters[k] = iter(v)

# iter longest dataloader
tqdm_object = tqdm(train_loaders[longest_loader], total=len(train_loaders[longest_loader]))
for iteration,(x, y, init_lens) in enumerate(tqdm_object):

    optimizer.zero_grad()
        
    total_loss = torch.zeros(1, requires_grad=False, device=CFG.device)
    total_batch = x.size(0)
    # loss_1
    loss = cal_loss_multi_task(model, criterion, x, y, init_lens, task_id = task_ids[longest_loader])
    total_loss = total_loss + loss.item() * task_weights[longest_loader]
    loss *= task_weights[longest_loader] # mul weight
    loss.backward()
    # calculate other tasks' loss
    for k, v in loader_iters.items():
        try:
            (x, y, init_lens) = next(v)
        except StopIteration: # recover other tasks' iter
            loader_iters[k] = iter(train_loaders[k])
            (x, y, init_lens) = next(loader_iters[k])
        total_batch += x.size(0)
        # loss_i
        loss = cal_loss_multi_task(model, criterion, x, y, init_lens, task_id=task_ids[k])
        total_loss = total_loss + loss.item() * task_weights[k]
        loss *= task_weights[k]
        loss.backward()

    # total_loss.backward()
    optimizer.step()

训练时,首先确定batch数最多的任务(数据集A),把它作为训练的最外层,这一部分和单任务训练一致,对于其他任务,则分别创建一个迭代器iterator负责取数据。之后在循环数据集A的时候,每次算出loss_A后,会从其他迭代器中取出对应的数据并计算loss_B/loss_C...,之后根据任务权重对loss进行加权平均,并进行反向传播。

在上述代码中,为了节省显存,每次计算完loss,我都直接乘上权重反向传播了,这个好处在于,每个loss计算完后对应的计算图会被自动释放,如果显式显出加权平均,那么所有任务的计算图都会被保留~

可以看出,batch mixing其实只是在训练过程中加入多个数据集的batch,然后分别算出loss并反向传播罢了。(data mixing也差不多,只是粒度更细一点)

关于data mixing,其实就比batch mixing多了一步操作,就是把所有的数据集拼起来,然后对于每个样本都添加一个字段表示任务。取数据的时候直接从大数据集里面取就可以了。

class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes

到这里也就知道,data mixing的数据增强也很方便,在拼接数据集之前,各个数据集定义自己的增强方式即可。但是data mixing的最大问题在于:训练过程中,需要循环batch里的每个样本,根据任务的不同分配给不同的Heads处理loss,并行度其实很差。当然,对于大模型训练来说,有时候单张GPU就放一个样本,那这个劣势就相当于没有了~

最后总结就是,data mixing相比于batch mixing,任务粒度更细,并且对于多任务的支持更好(不像batch mixing每加一个任务就要改train代码,data mixing只需要写好数据集和对应处理的Head即可),但是并行度相比于batch mixing较差。