mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
fix: resolve CUDA OOM in train_grpo.py on GPUs with <=8GB VRAM
This commit is contained in:
@@ -83,23 +83,27 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
|
||||
max_new_tokens=args.max_gen_len,
|
||||
temperature=0.8,
|
||||
)
|
||||
outputs = rollout_result.output_ids
|
||||
completion_ids = rollout_result.completion_ids
|
||||
outputs = rollout_result.output_ids.clone()
|
||||
completion_ids = rollout_result.completion_ids.clone()
|
||||
completions = rollout_result.completions
|
||||
old_per_token_logps = rollout_result.per_token_logps.to(args.device)
|
||||
old_per_token_logps = rollout_result.per_token_logps.to(args.device).detach()
|
||||
prompt_lens = rollout_result.prompt_lens.to(args.device)
|
||||
full_mask = (outputs != tokenizer.pad_token_id).long()
|
||||
logp_pos = prompt_lens.unsqueeze(1) - 1 + torch.arange(completion_ids.size(1), device=args.device).unsqueeze(0)
|
||||
|
||||
# 在 policy/ref 前向占用显存之前先算 rewards,此时显存最充裕
|
||||
torch.cuda.empty_cache()
|
||||
rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen]
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
with autocast_ctx:
|
||||
res = model_unwrapped(outputs, attention_mask=full_mask)
|
||||
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
|
||||
per_token_logps = F.log_softmax(res.logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
ref_per_token_logps = F.log_softmax(ref_model(outputs, attention_mask=full_mask).logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
|
||||
rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen]
|
||||
|
||||
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
|
||||
for i in range(len(prompts)):
|
||||
|
||||
Reference in New Issue
Block a user