RLHF PPO的时候为了节约成本,训练时一个batch的数据会多次利用,即多次更新actor model和critic model。
和Reference KL distance用于限制PPO模型和原始SFT模型之间的距离(防止PPO训歪,这一项是加在Reward model产生的R中,如deepspeed chat)一样,多次更新actor model也需要有原始actor model作为约束,因此actor loss计算中会有一项$logits/old\_logits$(和原始KL相比,去除了log,近端策略优化裁剪)
actor loss代码如下(from deepspeed chat)
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
## policy gradient loss
log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
1.0 + self.cliprange)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
return pg_loss
那么一个batch第一次迭代的时候actor loss对应的ratio是1(因为actor model没有更新,两次forward产生的logits一致),那这次更新actor/critic model只依赖所谓的advantages(与R和V有关)。
真棒!
真好呢
final email test.
reply.
好评
好评