分类 深度学习 下的文章

年前分享一篇关于Sequence Packing的论文:2107.02027

笔者认为Packing是Transformer架构训练时一个比较比较核心的问题(应用也非常广泛,LLMs如Llama 3会拼接多个文档做上下文的拓展),因为Transformer一般有一个固定的Context Size S,所有样本都需要pad到S的长度,从而满足transformer的矩阵计算;而pad token是不参与loss计算的,所以这里就带来了Context的浪费;对于一些分布很极端的数据集如Wikipedia,一个epoch下来,pad token要占到所有训练token的50%。

在以上背景下,有两种可以解决pad token过多带来的计算损耗。一个是优化计算流程,使pad token不参与attention这样复杂度高的计算,但是这种方法往往需要非常工程的优化;另一个是降低pad token的比例,也就是所谓的Sequence Packing了——把多个样本或序列拼在一起。

Packing带来的好处是非常明显的:多个样本拼在一起,首先减少了样本数,其次单条数据中有效token的比例也增加,这些优点加快了模型的收敛。

那么如何确定packing的逻辑呢?如果考虑全局最优(优化目标是pad token最少),这个问题其实就是装箱问题,NP hard。所以我们能承受的一般是启发式等近似算法。接下来介绍论文里提到的两个Packing Strategy。

Non-Negative Least Squares Histogram-Packing(NNLSHP)

NNLSHP利用了训练数据分布的直方图信息进行最小二乘法求解Packing逻辑。

首先,给定Context Size S,它先定义了一系列Packing的组合。例如 S = 512 = 256 + 128 + 128。一条512的数据可以通过拼接一条256的数据和两条128的数据得到。这个组合其实就是所谓的Partition Number (谢谢欧拉)。为了减少组合数,算法定义一条数据最多由N条数据打包到一起。得到所有的组合之后,可以排布成一个矩阵。

其中每一行代表某一个长度的数据,每一列代表一种策略,列和小于等于N。例如对于S = 8,N = 3,Strategy Matrix如下图。总共有10种策略,第一种就是两个长度为1的数据和长度为6的数据拼在一起,得到一条长度为8的新数据。

matrix

得到这个Matrix后,整个数据集的分布我们也是知道的,即每个长度的数据各自有多少条,我们记为b;每个策略应该应用多少次,我们记为x。那么x可以通过求解Ax=b得到。

得到x之后,我们就可以愉快的进行Packing了。当然这边有一些细节,例如x按照定义应该都是正整数,但是实际应用时只要按照$x \geq 0$计算再舍入即可(这就是算法名字中的NN)。

Shortest-Pack-First Histogram Packing

第一个算法还是存在一些缺点,首先预设了每条数据最多由三条数据Pack得到,其次Strategy Matrix随着S增大也增大,在超长上下文时计算效率太低。所以我们希望可以把这个痛点解决,于是乎启发式的SPFHP来了~

其实内容只需要三行就可以解释清楚:

  1. 这个算法运行在数据分布直方图上,从长序列往短序列遍历
  2. 检查当前长度能不能和前一个长度进行合并,如果不能,那么他们单独占一个样本位置
  3. 如果可以(或者部分可以),那么和前一个长度进行合并,并更新直方图的信息

例如下图S = 8的情况;从8遍历到3,此时长度为3的数据可以和长度为4的数据进行合并,组成长度为7的四条数据

image-20250125150202379

接着长度为2的数据可以和长度为5的数据进行拼接,组成长度7的数据

image-20250125150354232

最后三个长度1的数据分别和5、6的数据进行合并

image-20250125150440659

最终得到的合并序列如下,对于一些空位,再填入pad token。

image-20250125150519134

这个算法运行复杂度是O(S),因此速度非常快。值得注意的是每次合并时,选择的是前一个长度的数据进行合并,也就是说,每次合并保证预留出来的空间最多。这个和传统的First-Fit或者磁盘管理中的最佳适应是不太一样的,最佳适应是要求剩余空间尽可能少,而这里是尽可能多。

论文里也解释了:两种策略从最终效果来说相差不大,而First-Fit复杂度更高,因此选择了short-Fit。复杂度高是因为这里多了一步二分查找的过程。

2025第一篇博客,Hello

RoPEv1: chatglm/baichuan中使用
RoPEv2: Llama中使用

两者区别

  • v1代码较为繁琐,但是和原始算法对应
  • v2使用torch.complex实现,更加明了,不过没有严格遵循原算法,qk转到复数域时,并没有使用相邻的神经元作为复数的实部和虚部,而是直接在特征维度上切分,一半实部,一半虚部 Su的解释:神经元是无序的(dot attention做内积,不依赖于维度顺序)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

### RoPEv1
def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device = 'cpu'):
    # (max_len, 1)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
    # (output_dim//2)
    ids = torch.arange(0, output_dim // 2, dtype=torch.float)
    theta = torch.pow(10000, -2 * ids / output_dim)

    # (max_len, output_dim//2)
    embeddings = position * theta  # pos / (10000^(2i/d))

    # (max_len, output_dim//2, 2)
    embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)

    # (bs, head, max_len, output_dim//2, 2)
    embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape))))

    # (bs, head, max_len, output_dim)
    embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))
    embeddings = embeddings.to(device)
    return embeddings

def RoPEv1(q: torch.tensor, k: torch.tensor):
    # q,k: (bs, head, max_len, output_dim)
    batch_size = q.shape[0]
    nums_head = q.shape[1]
    max_len = q.shape[2]
    output_dim = q.shape[-1]

    # (bs, head, max_len, output_dim)
    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)

    # cos_pos,sin_pos: (bs, head, max_len, output_dim)
    cos_pos = pos_emb[...,  1::2].repeat_interleave(2, dim=-1)  # (cos cos) (cos cos)...
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)  # (sin sin) (sin sin)...

    # q,k: (bs, head, max_len, output_dim)
    q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
    q2 = q2.reshape(q.shape)
    q = q * cos_pos + q2 * sin_pos

    k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
    k2 = k2.reshape(k.shape)
    k = k * cos_pos + k2 * sin_pos

    return q, k

### RoPEv2
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    # freqs.shape = [dim // 2]
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # t.shape = [seq_len]
    t = torch.arange(seq_len, device=freqs.device)
    # freqs.shape = [seq_len, dim // 2] 
    freqs = torch.outer(t, freqs).float()
    # https://pytorch.org/docs/stable/generated/torch.polar.html
    # get unit complex, each element has (cosx, isinx)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def RoPEv2(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)
    
    xq_ = torch.view_as_complex(xq_)
    xk_ = torch.view_as_complex(xk_)
    
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

### Attn
def attention(q, k, v, mask=None, dropout=None, use_RoPEv1=True, use_RoPEv2=False):
    # q.shape: (bs, head, seq_len, dk)
    # k.shape: (bs, head, seq_len, dk)
    # v.shape: (bs, head, seq_len, dk)
    assert not (use_RoPEv1 and use_RoPEv2)

    if use_RoPEv1:
        q, k = RoPEv1(q, k)
    elif use_RoPEv2:
        freqs_cis = precompute_freqs_cis(q.shape[-1],q.shape[-2])
        q, k = RoPEv2(q, k, freqs_cis)
    
    print(q.shape, k.shape, v.shape)

    d_k = k.size()[-1]

    att_logits = torch.matmul(q, k.transpose(-2, -1))  # (bs, head, seq_len, seq_len)
    att_logits /= math.sqrt(d_k)

    if mask is not None:
        att_logits = att_logits.masked_fill(mask == 0, -1e9)

    att_scores = F.softmax(att_logits, dim=-1)  # (bs, head, seq_len, seq_len)

    if dropout is not None:
        att_scores = dropout(att_scores)

    # (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk)
    return torch.matmul(att_scores, v)


if __name__ == '__main__':
    # (bs, head, seq_len, dk)
    q = torch.randn((8, 12, 10, 32))
    k = torch.randn((8, 12, 10, 32))
    v = torch.randn((8, 12, 10, 32))

    # check result
    resultv1 = attention(q, k, v, use_RoPEv1=True, use_RoPEv2=False)
    resultv2 = attention(q, k, v, use_RoPEv1=False, use_RoPEv2=True)
    print(torch.allclose(resultv1, resultv2, rtol=1e-2))

Pytorch中的kaiming_uniform中标准差stdv乘了一个因子$\sqrt{3}$

def kaiming_uniform_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):
    #......
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)

为什么会出现这个因子,原因其实和物理实验里B类误差需要除去$\sqrt{3}$一样。
均匀分布的方差是$Var = \frac {(b-a)^2}{12}$,而我们真正想要的bound是$(b-a)/2$,所以有$Var = \frac{1}{3} {bound}^2$,标准差则是$\frac{1}{\sqrt 3} bound$。所以求出标准差stdv后,还需要乘一个$\sqrt{3}$得到真正的bound。

不过理论虽然如此,实践中这个因子似乎已经被抛弃了。
https://github.com/pytorch/pytorch/issues/57109