mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[feat] reduce RL memory
This commit is contained in:
@@ -269,6 +269,8 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
|
||||
old_per_token_logps = torch.tensor([old_logps + [0.0] * ((max_len - 1) - len(old_logps)) for _, _, _, old_logps in packed_samples], device=args.device, dtype=torch.float32)
|
||||
full_mask = (input_ids != tokenizer.pad_token_id).long()
|
||||
|
||||
rewards = calculate_rewards(prompts, completions, gt_batch, tools_batch, args.num_generations, reward_model, device=args.device, turn_outputs_batch=turn_outputs_batch, unfinished_batch=unfinished_batch)
|
||||
|
||||
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
with autocast_ctx:
|
||||
res = model_unwrapped(input_ids, attention_mask=full_mask)
|
||||
@@ -288,7 +290,6 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
|
||||
completion_mask = completion_mask * (pos <= eos_idx.unsqueeze(1)).float()
|
||||
token_counts = completion_mask.sum(dim=1)
|
||||
valid_rows = token_counts > 0
|
||||
rewards = calculate_rewards(prompts, completions, gt_batch, tools_batch, args.num_generations, reward_model, device=args.device, turn_outputs_batch=turn_outputs_batch, unfinished_batch=unfinished_batch)
|
||||
|
||||
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
|
||||
for i in range(len(messages_batch)):
|
||||
|
||||
@@ -83,18 +83,15 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
|
||||
max_new_tokens=args.max_gen_len,
|
||||
temperature=0.8,
|
||||
)
|
||||
outputs = rollout_result.output_ids.clone()
|
||||
completion_ids = rollout_result.completion_ids.clone()
|
||||
outputs = rollout_result.output_ids
|
||||
completion_ids = rollout_result.completion_ids
|
||||
completions = rollout_result.completions
|
||||
old_per_token_logps = rollout_result.per_token_logps.to(args.device).detach()
|
||||
prompt_lens = rollout_result.prompt_lens.to(args.device)
|
||||
full_mask = (outputs != tokenizer.pad_token_id).long()
|
||||
logp_pos = prompt_lens.unsqueeze(1) - 1 + torch.arange(completion_ids.size(1), device=args.device).unsqueeze(0)
|
||||
|
||||
# 在 policy/ref 前向占用显存之前先算 rewards,此时显存最充裕
|
||||
torch.cuda.empty_cache()
|
||||
rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen]
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
with autocast_ctx:
|
||||
|
||||
Reference in New Issue
Block a user