无事水——RLHF PPO ppo_epochs

其他  ·  2024-06-04

RLHF PPO的时候为了节约成本,训练时一个batch的数据会多次利用,即多次更新actor model和critic model。

和Reference KL distance用于限制PPO模型和原始SFT模型之间的距离(防止PPO训歪,这一项是加在Reward model产生的R中,如deepspeed chat)一样,多次更新actor model也需要有原始actor model作为约束,因此actor loss计算中会有一项$logits/old\_logits$(重要性采样)

actor loss代码如下(from deepspeed chat)

    def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
        ## policy gradient loss
        log_ratio = (logprobs - old_logprobs) * mask
        ratio = torch.exp(log_ratio)
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                             1.0 + self.cliprange)
        pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
        return pg_loss

那么一个batch第一次迭代的时候actor loss对应的ratio是1(因为actor model没有更新,两次forward产生的logits一致),那这次更新critic model只依赖所谓的advantages(与R和V有关,另外actor model不更新,因为advantages和它无关)。

 
评论
hls开山大弟子

好评

Axuanz

好评

Axuanz

final email test.

Axuanz

reply.

chhnsvmvyd

真好呢

shhaslchot

真棒!

JJJYmmm's Blog. All Rights Reserved. Theme Jasmine by Kent Liao.