From 9c98cabc9a71820bb95fb55033cfda7816c2c35f Mon Sep 17 00:00:00 2001 From: jingyaogong Date: Sat, 15 Nov 2025 18:25:37 +0800 Subject: [PATCH] [fix] prompt length calculate --- trainer/train_ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py index 1cdb074..53a734f 100644 --- a/trainer/train_ppo.py +++ b/trainer/train_ppo.py @@ -124,7 +124,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche prompts = batch["prompt"] # list[str], length B enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_seq_len).to(args.device) # input_ids: [B, P], attention_mask: [B, P] - prompt_lengths = enc.attention_mask.sum(dim=1) # [B] + prompt_lengths = torch.full((enc.input_ids.size(0),), enc.input_ids.shape[1], dtype=torch.long, device=enc.input_ids.device) # [B] with torch.no_grad(): # DDP 模型需要使用 .module 访问 generate 方法