无事水——RLHF PPO ppo_epochs

其他 · 2024-06-04 · 1419 人浏览

RLHF PPO的时候为了节约成本,训练时一个batch的数据会多次利用,即多次更新actor model和critic model。

和Reference KL distance用于限制PPO模型和原始SFT模型之间的距离(防止PPO训歪,这一项是加在Reward model产生的R中,如deepspeed chat)一样,多次更新actor model也需要有原始actor model作为约束,因此actor loss计算中会有一项$logits/old\_logits$(和原始KL相比,去除了log,近端策略优化裁剪)

actor loss代码如下(from deepspeed chat)

    def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
        ## policy gradient loss
        log_ratio = (logprobs - old_logprobs) * mask
        ratio = torch.exp(log_ratio)
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                             1.0 + self.cliprange)
        pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
        return pg_loss

那么一个batch第一次迭代的时候actor loss对应的ratio是1(因为actor model没有更新,两次forward产生的logits一致),那这次更新actor/critic model只依赖所谓的advantages(与R和V有关)。

  1. shhaslchot 2024-11-19

    真棒!

  2. chhnsvmvyd 2024-11-13

    真好呢

  3. Axuanz (作者)  2024-07-20

    final email test.

    1. Axuanz (作者)  2024-07-20
      @Axuanz

      reply.

  4. hls开山大弟子 2024-07-19

    好评

    1. Axuanz (作者)  2024-07-20
      @hls开山大弟子

      好评

Theme Jasmine by Kent Liao