[feat] reduce RL memory

This commit is contained in:
jingyaogong
2026-05-06 15:07:28 +08:00
parent e73a407f7a
commit 802c15b2b4
2 changed files with 4 additions and 6 deletions
+2 -1
View File
@@ -269,6 +269,8 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
old_per_token_logps = torch.tensor([old_logps + [0.0] * ((max_len - 1) - len(old_logps)) for _, _, _, old_logps in packed_samples], device=args.device, dtype=torch.float32)
full_mask = (input_ids != tokenizer.pad_token_id).long()
rewards = calculate_rewards(prompts, completions, gt_batch, tools_batch, args.num_generations, reward_model, device=args.device, turn_outputs_batch=turn_outputs_batch, unfinished_batch=unfinished_batch)
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
with autocast_ctx:
res = model_unwrapped(input_ids, attention_mask=full_mask)
@@ -288,7 +290,6 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
completion_mask = completion_mask * (pos <= eos_idx.unsqueeze(1)).float()
token_counts = completion_mask.sum(dim=1)
valid_rows = token_counts > 0
rewards = calculate_rewards(prompts, completions, gt_batch, tools_batch, args.num_generations, reward_model, device=args.device, turn_outputs_batch=turn_outputs_batch, unfinished_batch=unfinished_batch)
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
for i in range(len(messages_batch)):
+2 -5
View File
@@ -83,18 +83,15 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
max_new_tokens=args.max_gen_len,
temperature=0.8,
)
outputs = rollout_result.output_ids.clone()
completion_ids = rollout_result.completion_ids.clone()
outputs = rollout_result.output_ids
completion_ids = rollout_result.completion_ids
completions = rollout_result.completions
old_per_token_logps = rollout_result.per_token_logps.to(args.device).detach()
prompt_lens = rollout_result.prompt_lens.to(args.device)
full_mask = (outputs != tokenizer.pad_token_id).long()
logp_pos = prompt_lens.unsqueeze(1) - 1 + torch.arange(completion_ids.size(1), device=args.device).unsqueeze(0)
# 在 policy/ref 前向占用显存之前先算 rewards,此时显存最充裕
torch.cuda.empty_cache()
rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen]
torch.cuda.empty_cache()
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
with autocast_ctx: