From b113b494cb33dd1907aae5521ee9d7f29804c120 Mon Sep 17 00:00:00 2001
From: Dxpsk <1395693839@qq.com>
Date: Tue, 24 Mar 2026 17:15:55 +0800
Subject: [PATCH] merge redundant forward passes for logps and aux_loss (in train_grpo.py)

---
 trainer/train_grpo.py | 28 +++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py
index 5e63779..a6c3526 100755
--- a/trainer/train_grpo.py
+++ b/trainer/train_grpo.py
@@ -110,22 +110,36 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
         completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]  # [B*num_gen, R]
 
-        def get_per_token_logps(mdl, input_ids, n_keep):
+        def get_per_token_logps_and_aux(mdl, input_ids, n_keep, use_moe=False):
             input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
-            logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
+
+            model_output = mdl(input_ids, logits_to_keep=n_keep + 1)
+
+            logits = model_output.logits[:, :-1, :]
             per_token_logps = []
             for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
                 ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
                 per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
-            return torch.stack(per_token_logps)
+
+            stacked_logps = torch.stack(per_token_logps)
+            if use_moe:
+                aux_loss = getattr(model_output, 'aux_loss', None)
+                if aux_loss is None:
+                    aux_loss = torch.tensor(0.0, device=input_ids.device)
+            else:
+                aux_loss = torch.tensor(0.0, device=input_ids.device)
+            return stacked_logps, aux_loss
 
         with autocast_ctx:
-            per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))  # [B*num_gen, R]
-            res = model(outputs) if lm_config.use_moe else None
-            aux_loss = res.aux_loss if res is not None else torch.tensor(0.0, device=args.device)
+            per_token_logps, aux_loss = get_per_token_logps_and_aux(
+                model,
+                outputs,
+                completion_ids.size(1),
+                use_moe=lm_config.use_moe
+            )
         with torch.no_grad():
-            ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))  # [B*num_gen, R]
+            ref_per_token_logps, _ = get_per_token_logps_and_aux(ref_model, outputs, completion_ids.size(1))  # [B*num_gen, R]
 
         completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
         rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)  # [B*num_gen]