merge redundant forward passes for logps and aux_loss (in train_grpo.py)

This commit is contained in:
Dxpsk 2026-03-24 17:15:55 +08:00
parent 349e74ec7b
commit b113b494cb

View File

@@ -110,22 +110,36 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
# Strip the prompt prefix from the generated sequences, keeping only the
# completion tokens.  Shape comment below is from the original author:
# [B*num_gen, R] — presumably R is the (max) completion length; confirm upstream.
completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):] # [B*num_gen, R]
def get_per_token_logps_and_aux(mdl, input_ids, n_keep, use_moe=False):
    """Run ONE forward pass through ``mdl`` and return both the per-token
    log-probabilities of the trailing ``n_keep`` tokens and the MoE
    auxiliary loss.

    Merging the two reads into a single forward pass removes the redundant
    second forward that previously existed only to fetch ``aux_loss``.

    Args:
        mdl: causal LM; calling it returns an object exposing ``.logits``
            and, for MoE models, optionally ``.aux_loss``.
        input_ids: token ids, shape ``[B, T]``.
        n_keep: number of trailing (completion) tokens to score.
        use_moe: when True, read ``aux_loss`` off the model output;
            otherwise a zero scalar tensor is returned in its place.

    Returns:
        Tuple ``(per_token_logps, aux_loss)`` where ``per_token_logps`` has
        shape ``[B, n_keep]`` and ``aux_loss`` is a scalar tensor
        (0.0 whenever no auxiliary loss applies).
    """
    # Tensors created under torch.inference_mode() cannot participate in
    # autograd; clone them out before the differentiable forward pass.
    if input_ids.is_inference():
        input_ids = input_ids.detach().clone()
    # Ask for n_keep + 1 logit positions, then drop the last one:
    # logits[t] predicts token t+1, so the final logit has no target here.
    model_output = mdl(input_ids, logits_to_keep=n_keep + 1)
    logits = model_output.logits[:, :-1, :]
    per_token_logps = []
    for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
        if ids_row.is_inference():
            ids_row = ids_row.detach().clone()
        # Gather the log-prob assigned to each realized token.
        row_logps = torch.gather(
            logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)
        ).squeeze(1)
        per_token_logps.append(row_logps)
    stacked_logps = torch.stack(per_token_logps)
    if use_moe:
        # Some model wrappers may not expose aux_loss; fall back to zero.
        aux_loss = getattr(model_output, 'aux_loss', None)
        if aux_loss is None:
            aux_loss = torch.tensor(0.0, device=input_ids.device)
    else:
        aux_loss = torch.tensor(0.0, device=input_ids.device)
    return stacked_logps, aux_loss
with autocast_ctx:
    # Single merged forward pass: yields both the policy log-probs and
    # (for MoE configs) the auxiliary load-balancing loss — the previous
    # second forward pass used only to read aux_loss is gone.
    per_token_logps, aux_loss = get_per_token_logps_and_aux(
        model,
        outputs,
        completion_ids.size(1),
        use_moe=lm_config.use_moe,
    )
    # Reference-policy log-probs are KL targets only; no gradients needed.
    # NOTE(review): no_grad nested under autocast_ctx so the ref pass also
    # runs in mixed precision — original indentation was lost, confirm.
    with torch.no_grad():
        ref_per_token_logps, _ = get_per_token_logps_and_aux(
            ref_model, outputs, completion_ids.size(1)
        )  # [B*num_gen, R]
completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)  # [B*num_gen]