RoPE两种实现方式

深度学习 · 2024-04-29 · 1165 人浏览

RoPEv1: chatglm/baichuan中使用
RoPEv2: Llama中使用

两者区别

  • v1代码较为繁琐,但是和原始算法对应
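  • v2使用torch.complex实现,更加简洁明了。注意:下面的v2代码(Meta官方Llama实现)通过 reshape(..., -1, 2) 同样把相邻的两个神经元作为复数的实部和虚部,与v1在数学上等价(因此文末的allclose检查可以通过);真正没有严格遵循原算法的是 HuggingFace transformers 中Llama的 rotate_half 写法——它在qk转到"复数域"时并不使用相邻的神经元,而是直接在特征维度上切分,前一半作为实部,后一半作为虚部。Su的解释:神经元是无序的(dot attention做内积,不依赖于维度顺序)。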
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

### RoPEv1
def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device = 'cpu'):
    """Build the interleaved sin/cos table used by ``RoPEv1``.

    For position ``p`` and pair index ``i`` (``0 <= i < output_dim // 2``) the
    angle is ``a = p * 10000 ** (-2 * i / output_dim)``. Along the last
    dimension the table stores ``[sin(a_0), cos(a_0), sin(a_1), cos(a_1), ...]``.

    Args:
        batch_size: Number of identical copies along dim 0.
        nums_head: Number of identical copies along dim 1.
        max_len: Number of positions (0 .. max_len - 1).
        output_dim: Feature size; must be even, otherwise the final reshape
            fails.
        device: Device on which the table is created.

    Returns:
        float32 tensor of shape ``(batch_size, nums_head, max_len, output_dim)``.
    """
    # Allocate directly on the target device instead of computing on the CPU
    # and copying the (potentially large, repeated) table over afterwards.
    # (max_len, 1)
    position = torch.arange(0, max_len, dtype=torch.float, device=device).unsqueeze(-1)
    # (output_dim // 2,)
    ids = torch.arange(0, output_dim // 2, dtype=torch.float, device=device)
    theta = torch.pow(10000, -2 * ids / output_dim)

    # (max_len, output_dim // 2): pos / 10000^(2i/d)
    angles = position * theta

    # (max_len, output_dim // 2, 2): sin and cos of each angle side by side
    table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)

    # (bs, head, max_len, output_dim // 2, 2)
    table = table.repeat((batch_size, nums_head, 1, 1, 1))

    # (bs, head, max_len, output_dim): merge the (pair, sin/cos) axes
    return table.reshape(batch_size, nums_head, max_len, output_dim)

def _rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    """Map each adjacent pair ``(x_{2i}, x_{2i+1})`` in the last dim to ``(-x_{2i+1}, x_{2i})``."""
    rotated = torch.stack([-x[..., 1::2], x[..., ::2]], dim=-1)
    return rotated.reshape(x.shape)


def RoPEv1(q: torch.Tensor, k: torch.Tensor):
    """Apply rotary position embedding to ``q`` and ``k`` (interleaved form).

    Each adjacent feature pair ``(x_{2i}, x_{2i+1})`` at position ``p`` is
    rotated by the angle ``p * 10000 ** (-2i / d)``, computed as
    ``x * cos + rotate_every_two(x) * sin``.

    Args:
        q: Query tensor of shape ``(bs, head, seq_len, dim)`` with even ``dim``.
        k: Key tensor of the same shape as ``q``.

    Returns:
        ``(q_rot, k_rot)`` with the same shapes and dtypes as ``q`` and ``k``.
    """
    batch_size = q.shape[0]
    nums_head = q.shape[1]
    max_len = q.shape[2]
    output_dim = q.shape[-1]

    # (bs, head, max_len, output_dim), interleaved [sin, cos, sin, cos, ...]
    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)

    # Duplicate every cos / sin so it lines up with both members of its pair:
    # (cos0 cos0 cos1 cos1 ...) and (sin0 sin0 sin1 sin1 ...)
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)

    # The table is float32, so the arithmetic promotes low-precision inputs;
    # cast back so fp16/bf16 inputs keep their dtype (as RoPEv2 does).
    q_out = (q * cos_pos + _rotate_every_two(q) * sin_pos).type_as(q)
    k_out = (k * cos_pos + _rotate_every_two(k) * sin_pos).type_as(k)
    return q_out, k_out

### RoPEv2
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0, device=None):
    """Precompute the unit complex rotations used by ``RoPEv2``.

    Entry ``[t, i]`` equals ``cos(a) + 1j * sin(a)`` with
    ``a = t * theta ** (-2i / dim)``, for position ``t`` and pair index ``i``.

    Args:
        dim: Feature size of each head; one frequency per pair of features.
        seq_len: Number of positions (0 .. seq_len - 1).
        theta: Base of the geometric frequency progression.
        device: Optional device for the result. ``None`` uses PyTorch's
            default device, exactly as before this parameter existed.

    Returns:
        complex64 tensor of shape ``(seq_len, dim // 2)``.
    """
    # (dim // 2,) inverse frequencies 1 / theta^(2i / dim)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim))
    # (seq_len,) positions, created as float so torch.outer needs no int/float promotion
    t = torch.arange(seq_len, device=freqs.device, dtype=torch.float32)
    # (seq_len, dim // 2) rotation angles
    angles = torch.outer(t, freqs)
    # torch.polar(abs, angle) = abs * (cos(angle) + 1j * sin(angle)); abs = 1 gives unit rotations
    return torch.polar(torch.ones_like(angles), angles)

def RoPEv2(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Apply rotary position embedding via complex multiplication.

    Adjacent feature pairs ``(x_{2i}, x_{2i+1})`` are viewed as the complex
    number ``x_{2i} + 1j * x_{2i+1}`` and multiplied by ``freqs_cis``, which
    rotates each pair by its angle (the same pairing RoPEv1 uses).

    Args:
        xq: Query tensor whose last dimension (even size ``dim``) holds features.
        xk: Key tensor with the same feature layout.
        freqs_cis: Complex tensor broadcastable against the complex view of
            shape ``(*xq.shape[:-1], dim // 2)``, e.g. ``(seq_len, dim // 2)``
            from ``precompute_freqs_cis`` when the sequence axis is second-to-last.

    Returns:
        ``(xq_rot, xk_rot)`` with the shapes and dtypes of ``xq`` and ``xk``.
    """
    # Pair adjacent features and reinterpret each pair as one complex value.
    xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_c = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Rotate, split back into (real, imag) and merge the last two dims.
    # flatten(-2) instead of a fixed flatten(3) so inputs of any rank work.
    xq_out = torch.view_as_real(xq_c * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_c * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)

### Attn
def attention(q, k, v, mask=None, dropout=None, use_RoPEv1=True, use_RoPEv2=False):
    """Scaled dot-product attention with optional rotary position embedding.

    Args:
        q: Queries, shape ``(bs, head, seq_len, dk)``.
        k: Keys, shape ``(bs, head, seq_len, dk)``.
        v: Values, shape ``(bs, head, seq_len, dk)``.
        mask: Optional tensor broadcastable to ``(bs, head, seq_len, seq_len)``;
            positions where ``mask == 0`` receive (effectively) zero weight.
        dropout: Optional callable (e.g. ``nn.Dropout``) applied to the
            attention weights.
        use_RoPEv1: Rotate q/k with ``RoPEv1`` before the dot product.
        use_RoPEv2: Rotate q/k with ``RoPEv2`` before the dot product.

    Returns:
        Tensor of shape ``(bs, head, seq_len, dk)``.

    Raises:
        ValueError: If both ``use_RoPEv1`` and ``use_RoPEv2`` are set.
    """
    if use_RoPEv1 and use_RoPEv2:
        raise ValueError("use_RoPEv1 and use_RoPEv2 are mutually exclusive")

    if use_RoPEv1:
        q, k = RoPEv1(q, k)
    elif use_RoPEv2:
        # precompute_freqs_cis builds its table on the default device; move it
        # to q's device so the complex multiply in RoPEv2 does not mix devices.
        freqs_cis = precompute_freqs_cis(q.shape[-1], q.shape[-2]).to(q.device)
        q, k = RoPEv2(q, k, freqs_cis)

    d_k = k.size(-1)

    # (bs, head, seq_len, seq_len)
    att_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Use the most negative finite value of the logits' dtype: a fixed
        # -1e9 overflows float16 and makes masked_fill raise.
        att_logits = att_logits.masked_fill(mask == 0, torch.finfo(att_logits.dtype).min)

    att_scores = F.softmax(att_logits, dim=-1)  # (bs, head, seq_len, seq_len)

    if dropout is not None:
        att_scores = dropout(att_scores)

    # (bs, head, seq_len, seq_len) @ (bs, head, seq_len, dk) -> (bs, head, seq_len, dk)
    return torch.matmul(att_scores, v)


if __name__ == '__main__':
    # Fixed seed so the comparison below is reproducible from run to run.
    torch.manual_seed(0)

    # (bs, head, seq_len, dk)
    q = torch.randn((8, 12, 10, 32))
    k = torch.randn((8, 12, 10, 32))
    v = torch.randn((8, 12, 10, 32))

    # Both variants rotate adjacent feature pairs by the same angles, so the
    # attention outputs should agree up to float rounding. An absolute
    # tolerance is needed too: with only rtol, outputs that happen to land
    # very close to zero fail the check on rounding noise alone.
    resultv1 = attention(q, k, v, use_RoPEv1=True, use_RoPEv2=False)
    resultv2 = attention(q, k, v, use_RoPEv1=False, use_RoPEv2=True)
    print(torch.allclose(resultv1, resultv2, rtol=1e-2, atol=1e-5))
  1. Axuanz (作者)  2024-05-23

    qwen用的rope是v2 但实现没有llama优雅(

  2. 歪贼歪 2024-04-29

    很好rope 使我脑袋dizzy

    1. Axuanz (作者)  2024-05-02
      @歪贼歪

      很好video diffusion 使我脑袋dizzy

Theme Jasmine by Kent Liao