RoPEv1: used in chatglm/baichuan
RoPEv2: used in Llama
Differences between the two:
- v1's code is more verbose, but it maps one-to-one onto the original algorithm
- v2 is implemented with torch.complex and reads much more cleanly. Strictly speaking, the complex view used below (reshape(..., -1, 2) followed by view_as_complex) still takes adjacent neurons as the real and imaginary parts, just like the original algorithm; it is the rotate_half-style variant (e.g., HF transformers) that deviates from the original layout, splitting the feature dimension down the middle, one half real and one half imaginary. Su's explanation of why that layout is also fine: neurons are unordered (dot-product attention takes an inner product, which does not depend on the order of dimensions). A small sketch contrasting the two layouts follows this list.
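To make that last point concrete, here is a small sketch of my own (the function names rope_adjacent / rope_half_split are mine, not from any library): the adjacent-pair layout used by the complex implementation below and the half-split rotate_half layout differ only by a fixed permutation of neurons, so once q and k are permuted consistently the attention score is unchanged.

```python
import torch

def rope_adjacent(x: torch.Tensor, pos: int) -> torch.Tensor:
    # pair dims (2j, 2j+1) as (real, imag), like the view_as_complex code below
    d = x.shape[-1]
    theta = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))   # (d//2,)
    freqs = torch.polar(torch.ones(d // 2), pos * theta)           # e^{i * pos * theta_j}
    xc = torch.view_as_complex(x.reshape(-1, 2))
    return torch.view_as_real(xc * freqs).flatten()

def rope_half_split(x: torch.Tensor, pos: int) -> torch.Tensor:
    # rotate_half style: dims (j, j + d/2) form a pair, rotated by theta_j
    d = x.shape[-1]
    theta = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
    cos, sin = torch.cos(pos * theta).repeat(2), torch.sin(pos * theta).repeat(2)
    x1, x2 = x[: d // 2], x[d // 2:]
    return x * cos + torch.cat([-x2, x1]) * sin

d = 8
# permutation that interleaves the two halves: [0, d/2, 1, d/2+1, ...]
perm = torch.arange(d).reshape(2, d // 2).t().flatten()
x = torch.randn(d)
# the two layouts coincide up to this fixed permutation of neurons
print(torch.allclose(rope_half_split(x, 5)[perm], rope_adjacent(x[perm], 5), atol=1e-6))  # True
# q and k are permuted by the same perm, so the attention score is unchanged
q, k = torch.randn(d), torch.randn(d)
print(torch.allclose(rope_half_split(q, 5) @ rope_half_split(k, 2),
                     rope_adjacent(q[perm], 5) @ rope_adjacent(k[perm], 2), atol=1e-5))  # True
```

Since q and k come out of learned projections, that permutation can simply be absorbed into W_q and W_k, which is why the pairing layout is a free choice.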
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


### RoPEv1
def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device='cpu'):
    # (max_len, 1)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
    # (output_dim//2,)
    ids = torch.arange(0, output_dim // 2, dtype=torch.float)
    theta = torch.pow(10000, -2 * ids / output_dim)
    # (max_len, output_dim//2)
    embeddings = position * theta  # pos / (10000^(2i/d))
    # (max_len, output_dim//2, 2)
    embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
    # (bs, head, max_len, output_dim//2, 2)
    embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape))))
    # (bs, head, max_len, output_dim), interleaved along the last dim as (sin, cos, sin, cos, ...)
    embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))
    embeddings = embeddings.to(device)
    return embeddings

def RoPEv1(q: torch.Tensor, k: torch.Tensor):
    # q, k: (bs, head, max_len, output_dim)
    batch_size = q.shape[0]
    nums_head = q.shape[1]
    max_len = q.shape[2]
    output_dim = q.shape[-1]
    # (bs, head, max_len, output_dim)
    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)
    # cos_pos, sin_pos: (bs, head, max_len, output_dim)
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)  # (cos0, cos0), (cos1, cos1), ...
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)   # (sin0, sin0), (sin1, sin1), ...
    # rotate adjacent pairs: q2 = (-q1, q0, -q3, q2, ...)
    q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
    q2 = q2.reshape(q.shape)
    q = q * cos_pos + q2 * sin_pos
    k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
    k2 = k2.reshape(k.shape)
    k = k * cos_pos + k2 * sin_pos
    return q, k

### RoPEv2
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    # freqs.shape = [dim // 2]
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # t.shape = [seq_len]
    t = torch.arange(seq_len, device=freqs.device)
    # freqs.shape = [seq_len, dim // 2]
    freqs = torch.outer(t, freqs).float()
    # https://pytorch.org/docs/stable/generated/torch.polar.html
    # unit-modulus complex numbers: each element is cos(x) + i*sin(x)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def RoPEv2(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    # (..., dim) -> (..., dim//2, 2): adjacent dims become the (real, imag) parts
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # freqs_cis: (seq_len, dim//2), broadcasts against (bs, head, seq_len, dim//2)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

### Attn
def attention(q, k, v, mask=None, dropout=None, use_RoPEv1=True, use_RoPEv2=False):
    # q, k, v: (bs, head, seq_len, dk)
    assert not (use_RoPEv1 and use_RoPEv2)
    if use_RoPEv1:
        q, k = RoPEv1(q, k)
    elif use_RoPEv2:
        freqs_cis = precompute_freqs_cis(q.shape[-1], q.shape[-2])
        q, k = RoPEv2(q, k, freqs_cis)
    d_k = k.size()[-1]
    att_logits = torch.matmul(q, k.transpose(-2, -1))  # (bs, head, seq_len, seq_len)
    att_logits /= math.sqrt(d_k)
    if mask is not None:
        att_logits = att_logits.masked_fill(mask == 0, -1e9)
    att_scores = F.softmax(att_logits, dim=-1)  # (bs, head, seq_len, seq_len)
    if dropout is not None:
        att_scores = dropout(att_scores)
    # (bs, head, seq_len, seq_len) @ (bs, head, seq_len, dk) -> (bs, head, seq_len, dk)
    return torch.matmul(att_scores, v)

if __name__ == '__main__':
    # (bs, head, seq_len, dk)
    q = torch.randn((8, 12, 10, 32))
    k = torch.randn((8, 12, 10, 32))
    v = torch.randn((8, 12, 10, 32))
    # check that both implementations give the same attention output
    resultv1 = attention(q, k, v, use_RoPEv1=True, use_RoPEv2=False)
    resultv2 = attention(q, k, v, use_RoPEv1=False, use_RoPEv2=True)
    print(torch.allclose(resultv1, resultv2, rtol=1e-2))  # True
```
Qwen's RoPE is also the v2 flavor, but the implementation is not as elegant as Llama's (
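For reference, the rotate_half-style interface used by HF-transformers-family models (Qwen included, as far as I recall) looks roughly like the sketch below. This is a simplified rewrite from memory, not the actual Qwen source, so details may differ across versions.

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # split the head dim in half and swap the halves with a sign flip: [x1, x2] -> [-x2, x1]
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # cos/sin: (seq_len, head_dim), broadcast over (bs, head, seq_len, head_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

bs, n_head, seq_len, head_dim = 2, 4, 10, 32
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq_len).float()
freqs = torch.outer(t, inv_freq)          # (seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)   # each theta_j is shared by dims j and j + d/2
cos, sin = emb.cos(), emb.sin()

q = torch.randn(bs, n_head, seq_len, head_dim)
k = torch.randn(bs, n_head, seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)           # shapes unchanged: (2, 4, 10, 32)
```

The extra cos/sin caching and broadcasting bookkeeping is presumably what makes it feel less elegant than the complex-number version.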
Great, RoPE makes my head dizzy.
Great, video diffusion makes my head dizzy too.