diff --git a/trainer/train_agent.py b/trainer/train_agent.py index 2138fb2..e652c3e 100644 --- a/trainer/train_agent.py +++ b/trainer/train_agent.py @@ -269,6 +269,8 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model old_per_token_logps = torch.tensor([old_logps + [0.0] * ((max_len - 1) - len(old_logps)) for _, _, _, old_logps in packed_samples], device=args.device, dtype=torch.float32) full_mask = (input_ids != tokenizer.pad_token_id).long() + rewards = calculate_rewards(prompts, completions, gt_batch, tools_batch, args.num_generations, reward_model, device=args.device, turn_outputs_batch=turn_outputs_batch, unfinished_batch=unfinished_batch) + model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model with autocast_ctx: res = model_unwrapped(input_ids, attention_mask=full_mask) @@ -288,7 +290,6 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model completion_mask = completion_mask * (pos <= eos_idx.unsqueeze(1)).float() token_counts = completion_mask.sum(dim=1) valid_rows = token_counts > 0 - rewards = calculate_rewards(prompts, completions, gt_batch, tools_batch, args.num_generations, reward_model, device=args.device, turn_outputs_batch=turn_outputs_batch, unfinished_batch=unfinished_batch) if args.debug_mode and is_main_process() and step % args.debug_interval == 0: for i in range(len(messages_batch)): diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py index 2497ed6..52f7c7d 100755 --- a/trainer/train_grpo.py +++ b/trainer/train_grpo.py @@ -83,18 +83,15 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod max_new_tokens=args.max_gen_len, temperature=0.8, ) - outputs = rollout_result.output_ids.clone() - completion_ids = rollout_result.completion_ids.clone() + outputs = rollout_result.output_ids + completion_ids = rollout_result.completion_ids completions = rollout_result.completions old_per_token_logps = rollout_result.per_token_logps.to(args.device).detach() prompt_lens = rollout_result.prompt_lens.to(args.device) full_mask = (outputs != tokenizer.pad_token_id).long() logp_pos = prompt_lens.unsqueeze(1) - 1 + torch.arange(completion_ids.size(1), device=args.device).unsqueeze(0) - # 在 policy/ref 前向占用显存之前先算 rewards,此时显存最充裕 - torch.cuda.empty_cache() rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen] - torch.cuda.empty_cache() model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model with autocast_ctx: