In RLHF PPO, to save cost, each batch of data is reused multiple times during training, i.e., the actor model and critic model are updated several times per batch.

Just as the KL distance to the reference model constrains the PPO policy from drifting too far from the original SFT model (to keep PPO from going off the rails; in DeepSpeed-Chat this term is folded into the reward R produced by the reward model), updating the actor multiple times per batch also needs the original actor as a constraint. The actor loss therefore contains an importance-sampling ratio $\pi_\theta / \pi_{\theta_{old}}$, computed in log space as $\exp(logprobs - old\_logprobs)$.

The actor loss code is as follows (from DeepSpeed-Chat):

    def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
        ## policy gradient loss
        log_ratio = (logprobs - old_logprobs) * mask
        ratio = torch.exp(log_ratio)
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                             1.0 + self.cliprange)
        pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
        return pg_loss

On the first pass over a batch, the ratio in the actor loss is exactly 1 (the actor model has not been updated yet, so the two forward passes produce identical logprobs). This first actor update is therefore driven only by the so-called advantages (which depend on R and V; they are computed from the old model's outputs and carry no gradient back into the actor).
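To make this concrete, here is a minimal self-contained sketch (a standalone copy of the DeepSpeed-Chat loss above, with hypothetical inputs) showing that on the first pass `logprobs == old_logprobs`, the ratio is exactly 1, the clip is inactive, and the loss collapses to the masked mean of `-advantages`:

```python
import torch

def actor_loss_fn(logprobs, old_logprobs, advantages, mask, cliprange=0.2):
    # PPO clipped surrogate, same form as the DeepSpeed-Chat version above
    log_ratio = (logprobs - old_logprobs) * mask
    ratio = torch.exp(log_ratio)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    return torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()

# First pass over a batch: the actor has not been updated yet, so the
# new logprobs equal the old ones and the ratio is 1 everywhere.
torch.manual_seed(0)
logprobs = torch.randn(2, 5)
old_logprobs = logprobs.clone()
advantages = torch.randn(2, 5)
mask = torch.ones(2, 5)

loss = actor_loss_fn(logprobs, old_logprobs, advantages, mask)
# With ratio == 1 both surrogate terms coincide, so the loss reduces to
# the plain policy-gradient term: -mean(advantages) over the mask.
expected = -(advantages * mask).sum() / mask.sum()
print(torch.allclose(loss, expected))
```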

Today's LeetCode problem was about monotonic stacks. I first implemented it with a `stack`, but the problem asks for a `vector` to be returned; since `std::stack` is a container adaptor that exposes no iterators, I used pointers to locate the start and end of the stack's storage, as in the code below.

class Solution {
public:
    vector<int> mostCompetitive(vector<int>& nums, int k) {
        stack<int> stk;
        int n = nums.size();
        for(int i=0;i<n;i++){
            while(!stk.empty() && nums[i] < stk.top() && stk.size() + n - i - 1>= k)
                stk.pop();
            stk.push(nums[i]);
        }
        int *begin = &stk.top() + 1 - stk.size();
        return vector<int>(begin, begin + k);
    }
};

But the code above fails the last few test cases with a heap-overflow error; my guess was a memory-contiguity problem with `stack`. `std::stack` uses `std::deque` (a double-ended queue) as its default underlying container, and a deque stores what is logically one array in separate memory blocks. Below is llama-70b's answer.

That's correct. `std::deque` does not guarantee contiguous memory allocation. It's a double-ended queue that typically uses a sequence of fixed-size arrays (called "blocks" or "chunks") to store its elements. Each block has a fixed capacity, and when the deque grows or shrinks, new blocks are allocated or deallocated as needed.

The elements in a `std::deque` are not stored in a single contiguous block of memory. Instead, the deque maintains a sequence of pointers to these blocks, which lets it grow efficiently at either end while still supporting O(1) random access.

This design provides several benefits, such as:

1. Efficient insertion and deletion at both ends, with amortized O(1) time complexity.
2. Good cache locality for iterating over the elements.
3. Ability to handle large datasets without running out of contiguous memory.

However, this design also means that `std::deque` does not provide a contiguous block of memory, which can be a limitation in certain situations, such as:

1. When you need to pass a pointer to a contiguous array to a C-style API or a function that expects a contiguous array.
2. When you need to use pointer arithmetic or indexing into the array.

In contrast, `std::vector` does guarantee contiguous memory allocation, which is why it's often preferred when contiguous memory is required.

In the case of `std::stack`, when you use the default underlying container (`std::deque`), you don't have a guarantee of contiguous memory. But when you specify `std::vector` as the underlying container, you do get contiguous memory, which is why the original code worked when using `stack<int, vector<int>>`.

Fix: change `stack<int>` to `stack<int, vector<int>>`.

RoPEv1: used in chatglm/baichuan
RoPEv2: used in Llama

Differences between the two:

  • v1's code is more verbose, but it maps one-to-one onto the original algorithm
  • v2 is implemented with torch.complex and is much clearer. It does not always follow the original algorithm strictly: when mapping q/k to the complex domain, some implementations (e.g. HuggingFace's rotate_half) do not take adjacent neurons as the real and imaginary parts, but instead split the feature dimension in half, with the first half real and the second half imaginary. (The complex version below does pair adjacent neurons.) Su's explanation for why both work: neurons are unordered (dot attention takes inner products, which do not depend on dimension order)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

### RoPEv1
def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device = 'cpu'):
    # (max_len, 1)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
    # (output_dim//2)
    ids = torch.arange(0, output_dim // 2, dtype=torch.float)
    theta = torch.pow(10000, -2 * ids / output_dim)

    # (max_len, output_dim//2)
    embeddings = position * theta  # pos / (10000^(2i/d))

    # (max_len, output_dim//2, 2)
    embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)

    # (bs, head, max_len, output_dim//2, 2)
    embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape))))

    # (bs, head, max_len, output_dim)
    embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))
    embeddings = embeddings.to(device)
    return embeddings

def RoPEv1(q: torch.Tensor, k: torch.Tensor):
    # q,k: (bs, head, max_len, output_dim)
    batch_size = q.shape[0]
    nums_head = q.shape[1]
    max_len = q.shape[2]
    output_dim = q.shape[-1]

    # (bs, head, max_len, output_dim)
    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)

    # cos_pos,sin_pos: (bs, head, max_len, output_dim)
    cos_pos = pos_emb[...,  1::2].repeat_interleave(2, dim=-1)  # (cos cos) (cos cos)...
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)  # (sin sin) (sin sin)...

    # q,k: (bs, head, max_len, output_dim)
    q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
    q2 = q2.reshape(q.shape)
    q = q * cos_pos + q2 * sin_pos

    k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
    k2 = k2.reshape(k.shape)
    k = k * cos_pos + k2 * sin_pos

    return q, k

### RoPEv2
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    # freqs.shape = [dim // 2]
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # t.shape = [seq_len]
    t = torch.arange(seq_len, device=freqs.device)
    # freqs.shape = [seq_len, dim // 2] 
    freqs = torch.outer(t, freqs).float()
    # https://pytorch.org/docs/stable/generated/torch.polar.html
    # get unit complex, each element has (cosx, isinx)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def RoPEv2(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)
    
    xq_ = torch.view_as_complex(xq_)
    xk_ = torch.view_as_complex(xk_)
    
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

### Attn
def attention(q, k, v, mask=None, dropout=None, use_RoPEv1=True, use_RoPEv2=False):
    # q.shape: (bs, head, seq_len, dk)
    # k.shape: (bs, head, seq_len, dk)
    # v.shape: (bs, head, seq_len, dk)
    assert not (use_RoPEv1 and use_RoPEv2)

    if use_RoPEv1:
        q, k = RoPEv1(q, k)
    elif use_RoPEv2:
        freqs_cis = precompute_freqs_cis(q.shape[-1],q.shape[-2])
        q, k = RoPEv2(q, k, freqs_cis)
    
    d_k = k.size()[-1]

    att_logits = torch.matmul(q, k.transpose(-2, -1))  # (bs, head, seq_len, seq_len)
    att_logits /= math.sqrt(d_k)

    if mask is not None:
        att_logits = att_logits.masked_fill(mask == 0, -1e9)

    att_scores = F.softmax(att_logits, dim=-1)  # (bs, head, seq_len, seq_len)

    if dropout is not None:
        att_scores = dropout(att_scores)

    # (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk)
    return torch.matmul(att_scores, v)


if __name__ == '__main__':
    # (bs, head, seq_len, dk)
    q = torch.randn((8, 12, 10, 32))
    k = torch.randn((8, 12, 10, 32))
    v = torch.randn((8, 12, 10, 32))

    # check result
    resultv1 = attention(q, k, v, use_RoPEv1=True, use_RoPEv2=False)
    resultv2 = attention(q, k, v, use_RoPEv1=False, use_RoPEv2=True)
    print(torch.allclose(resultv1, resultv2, rtol=1e-2))
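As a quick sanity check on the pairing convention mentioned above: `view_as_complex` groups *adjacent* elements as (real, imag) pairs, which is exactly what RoPEv2's `reshape(*shape[:-1], -1, 2)` relies on, and multiplying by a unit complex number is a 2D rotation of each pair. A small illustrative sketch:

```python
import torch

# view_as_complex pairs adjacent elements as (real, imag).
x = torch.tensor([1., 2., 3., 4.])
xc = torch.view_as_complex(x.reshape(-1, 2))  # [1+2j, 3+4j]

# Multiplying by e^{i*theta} rotates each (x_{2i}, x_{2i+1}) pair
# by theta -- the core operation of RoPE.
theta = torch.tensor(0.5)
rot = torch.polar(torch.ones(()), theta)  # cos(theta) + i*sin(theta)
rotated = torch.view_as_real(xc * rot).flatten()

# The first rotated coordinate matches the 2x2 rotation matrix by hand:
manual0 = x[0] * torch.cos(theta) - x[1] * torch.sin(theta)
print(torch.allclose(rotated[0], manual0))
```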

In PyTorch's kaiming_uniform_, the standard deviation std is multiplied by a factor of $\sqrt{3}$:

def kaiming_uniform_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):
    #......
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)

Why does this factor appear? For the same reason that Type-B uncertainties in physics experiments are divided by $\sqrt{3}$.
A uniform distribution on $[a, b]$ has variance $Var = \frac{(b-a)^2}{12}$, while the bound we actually want is $(b-a)/2$, so $Var = \frac{1}{3}{bound}^2$ and the standard deviation is $\frac{1}{\sqrt 3} bound$. Hence, after computing the standard deviation std, we still have to multiply by $\sqrt{3}$ to recover the actual bound.
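The relation is easy to verify empirically: sampling from $U(-bound, bound)$ gives a sample standard deviation of about $bound/\sqrt{3}$, so the bound is recovered as $\sqrt{3} \cdot std$. A quick sketch:

```python
import math
import torch

bound = 1.0
torch.manual_seed(0)
samples = torch.empty(1_000_000).uniform_(-bound, bound)

# Var of U(-bound, bound) is (2*bound)^2 / 12 = bound^2 / 3,
# so std = bound / sqrt(3), i.e. bound = sqrt(3) * std.
print(samples.std().item())    # close to 0.577
print(bound / math.sqrt(3))    # 0.5773...
```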

That is the theory; in practice, however, this factor seems to have been effectively abandoned.
https://github.com/pytorch/pytorch/issues/57109