From 10776417aae969466d4fece8d5b4d5e808ebe89e Mon Sep 17 00:00:00 2001 From: TKiteRunner <1291390722@qq.com> Date: Wed, 6 May 2026 14:14:15 +0800 Subject: [PATCH] fix: resolve CUDA OOM in train_grpo.py on GPUs with <=8GB VRAM --- trainer/train_grpo.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py index ea63419..2497ed6 100755 --- a/trainer/train_grpo.py +++ b/trainer/train_grpo.py @@ -83,23 +83,27 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod max_new_tokens=args.max_gen_len, temperature=0.8, ) - outputs = rollout_result.output_ids - completion_ids = rollout_result.completion_ids + outputs = rollout_result.output_ids.clone() + completion_ids = rollout_result.completion_ids.clone() completions = rollout_result.completions - old_per_token_logps = rollout_result.per_token_logps.to(args.device) + old_per_token_logps = rollout_result.per_token_logps.to(args.device).detach() prompt_lens = rollout_result.prompt_lens.to(args.device) full_mask = (outputs != tokenizer.pad_token_id).long() logp_pos = prompt_lens.unsqueeze(1) - 1 + torch.arange(completion_ids.size(1), device=args.device).unsqueeze(0) + # 在 policy/ref 前向占用显存之前先算 rewards,此时显存最充裕 + torch.cuda.empty_cache() + rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen] + torch.cuda.empty_cache() + model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model with autocast_ctx: res = model_unwrapped(outputs, attention_mask=full_mask) aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device) per_token_logps = F.log_softmax(res.logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos) - + with torch.no_grad(): ref_per_token_logps = F.log_softmax(ref_model(outputs, attention_mask=full_mask).logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos) - rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen] if args.debug_mode and is_main_process() and step % args.debug_interval == 0: for i in range(len(prompts)):