fix: resolve CUDA OOM in train_grpo.py on GPUs with <=8GB VRAM

This commit is contained in:
TKiteRunner
2026-05-06 14:14:15 +08:00
parent 5020dc9dd4
commit 10776417aa
+9 -5
View File
@@ -83,23 +83,27 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
max_new_tokens=args.max_gen_len,
temperature=0.8,
)
outputs = rollout_result.output_ids
completion_ids = rollout_result.completion_ids
outputs = rollout_result.output_ids.clone()
completion_ids = rollout_result.completion_ids.clone()
completions = rollout_result.completions
old_per_token_logps = rollout_result.per_token_logps.to(args.device)
old_per_token_logps = rollout_result.per_token_logps.to(args.device).detach()
prompt_lens = rollout_result.prompt_lens.to(args.device)
full_mask = (outputs != tokenizer.pad_token_id).long()
logp_pos = prompt_lens.unsqueeze(1) - 1 + torch.arange(completion_ids.size(1), device=args.device).unsqueeze(0)
# 在 policy/ref 前向占用显存之前先算 rewards,此时显存最充裕
torch.cuda.empty_cache()
rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen]
torch.cuda.empty_cache()
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
with autocast_ctx:
res = model_unwrapped(outputs, attention_mask=full_mask)
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
per_token_logps = F.log_softmax(res.logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
with torch.no_grad():
ref_per_token_logps = F.log_softmax(ref_model(outputs, attention_mask=full_mask).logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen]
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
for i in range(len(prompts)):