diff --git a/train_gated_ppo.py b/train_gated_ppo.py index 6aa966b..58135c2 100644 --- a/train_gated_ppo.py +++ b/train_gated_ppo.py @@ -166,7 +166,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche # 修改部分:添加门控 ratio = torch.exp(actor_logp - old_logp) # [B] - ratio = ratio * torch.sigmoid(0.5 * ratio) + ratio = ratio * torch.sigmoid(0.1 * ratio) surr1 = ratio * advantages # [B] surr2 = torch.clamp(ratio, 1.0 - args.clip_epsilon, 1.0 + args.clip_epsilon) * advantages # [B]