无事水——RLHF PPO ppo_epochs

其他  ·  2024-06-04

RLHF PPO的时候为了节约成本,训练时一个batch的数据会多次利用,即多次更新actor model和critic model。

和Reference KL distance用于限制PPO模型和原始SFT模型之间的距离(防止PPO训歪,这一项是加在Reward model产生的R中,如deepspeed chat)一样,多次更新actor model也需要有原始actor model作为约束,因此actor loss计算中会有一项$logits/old\_logits$(和原始KL相比,去除了log,近端策略优化裁剪)

actor loss代码如下(from deepspeed chat)

    def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
        ## policy gradient loss
        log_ratio = (logprobs - old_logprobs) * mask
        ratio = torch.exp(log_ratio)
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                             1.0 + self.cliprange)
        pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
        return pg_loss

那么一个batch第一次迭代的时候actor loss对应的ratio是1(因为actor model没有更新,两次forward产生的logits一致),那这次更新actor/critic model只依赖所谓的advantages(与R和V有关)。

 
评论
hls开山大弟子
hls开山大弟子

好评

Axuanz
Axuanz

好评

Axuanz
Axuanz

final email test.

Axuanz
Axuanz

reply.

chhnsvmvyd
chhnsvmvyd

真好呢

shhaslchot
shhaslchot

真棒!

Axuanz的学习日记. All Rights Reserved. Theme Jasmine by Kent Liao.

鄂ICP备2023004395号