mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[fix] rollout bugs
This commit is contained in:
+22
-17
@@ -82,7 +82,8 @@ class TorchRolloutEngine(RolloutEngine):
|
|||||||
) # [B*num_gen, P+R]
|
) # [B*num_gen, P+R]
|
||||||
prompt_len = prompt_ids.size(1)
|
prompt_len = prompt_ids.size(1)
|
||||||
completion_ids = output_ids[:, prompt_len:] # [B*num_gen, R]
|
completion_ids = output_ids[:, prompt_len:] # [B*num_gen, R]
|
||||||
per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1))
|
full_mask = (output_ids != self.tokenizer.pad_token_id).long()
|
||||||
|
per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1), attention_mask=full_mask)
|
||||||
completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
|
completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
|
||||||
return RolloutResult(output_ids, completion_ids, per_token_logps, completions)
|
return RolloutResult(output_ids, completion_ids, per_token_logps, completions)
|
||||||
|
|
||||||
@@ -158,28 +159,32 @@ class SGLangRolloutEngine(RolloutEngine):
|
|||||||
def pad_to_tensor(seqs, max_len, pad_val=0):
|
def pad_to_tensor(seqs, max_len, pad_val=0):
|
||||||
return torch.tensor([s + [pad_val] * (max_len - len(s)) for s in seqs], device=device)
|
return torch.tensor([s + [pad_val] * (max_len - len(s)) for s in seqs], device=device)
|
||||||
|
|
||||||
|
pad_id = self.tokenizer.pad_token_id
|
||||||
return RolloutResult(
|
return RolloutResult(
|
||||||
output_ids=pad_to_tensor(all_output_ids, max_out_len),
|
output_ids=pad_to_tensor(all_output_ids, max_out_len, pad_val=pad_id),
|
||||||
completion_ids=pad_to_tensor(all_completion_ids, max_comp_len),
|
completion_ids=pad_to_tensor(all_completion_ids, max_comp_len, pad_val=pad_id),
|
||||||
per_token_logps=pad_to_tensor(all_logprobs, max_comp_len, pad_val=0.0),
|
per_token_logps=pad_to_tensor(all_logprobs, max_comp_len, pad_val=0.0),
|
||||||
completions=completions,
|
completions=completions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_policy(self, model: torch.nn.Module):
|
def update_policy(self, model: torch.nn.Module):
|
||||||
if dist.is_initialized() and dist.get_rank() != 0: return True
|
ok = True
|
||||||
unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||||
unwrapped = getattr(unwrapped, '_orig_mod', unwrapped)
|
unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||||
abs_path = os.path.abspath(self.shared_ckpt_path)
|
unwrapped = getattr(unwrapped, '_orig_mod', unwrapped)
|
||||||
state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()}
|
abs_path = os.path.abspath(self.shared_ckpt_path)
|
||||||
unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False)
|
state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()}
|
||||||
self.tokenizer.save_pretrained(abs_path)
|
unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False)
|
||||||
resp = self.http.post(
|
self.tokenizer.save_pretrained(abs_path)
|
||||||
f"{self.base_url}/update_weights_from_disk",
|
resp = self.http.post(
|
||||||
json={"model_path": abs_path},
|
f"{self.base_url}/update_weights_from_disk",
|
||||||
timeout=self.timeout
|
json={"model_path": abs_path},
|
||||||
)
|
timeout=self.timeout
|
||||||
if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights 失败: {resp.status_code}, {resp.text}")
|
)
|
||||||
return resp.status_code == 200
|
if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights 失败: {resp.status_code}, {resp.text}")
|
||||||
|
ok = resp.status_code == 200
|
||||||
|
if dist.is_initialized(): dist.barrier()
|
||||||
|
return ok
|
||||||
|
|
||||||
def flush_cache(self) -> bool:
|
def flush_cache(self) -> bool:
|
||||||
resp = self.http.post(f"{self.base_url}/flush_cache", timeout=30)
|
resp = self.http.post(f"{self.base_url}/flush_cache", timeout=30)
|
||||||
|
|||||||
@@ -267,16 +267,17 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
|
|||||||
prompt_lens = torch.tensor([prompt_len for _, _, prompt_len, _ in packed_samples], device=args.device)
|
prompt_lens = torch.tensor([prompt_len for _, _, prompt_len, _ in packed_samples], device=args.device)
|
||||||
full_response_masks = torch.tensor([mask + [0] * (max_len - len(mask)) for _, mask, _, _ in packed_samples], device=args.device, dtype=torch.float32)
|
full_response_masks = torch.tensor([mask + [0] * (max_len - len(mask)) for _, mask, _, _ in packed_samples], device=args.device, dtype=torch.float32)
|
||||||
old_per_token_logps = torch.tensor([old_logps + [0.0] * ((max_len - 1) - len(old_logps)) for _, _, _, old_logps in packed_samples], device=args.device, dtype=torch.float32)
|
old_per_token_logps = torch.tensor([old_logps + [0.0] * ((max_len - 1) - len(old_logps)) for _, _, _, old_logps in packed_samples], device=args.device, dtype=torch.float32)
|
||||||
|
full_mask = (input_ids != tokenizer.pad_token_id).long()
|
||||||
|
|
||||||
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||||
with autocast_ctx:
|
with autocast_ctx:
|
||||||
res = model_unwrapped(input_ids)
|
res = model_unwrapped(input_ids, attention_mask=full_mask)
|
||||||
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
|
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
|
||||||
logits = res.logits[:, :-1, :]
|
logits = res.logits[:, :-1, :]
|
||||||
per_token_logps = F.log_softmax(logits, dim=-1).gather(2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
|
per_token_logps = F.log_softmax(logits, dim=-1).gather(2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
ref_per_token_logps = compute_per_token_logps(ref_model, input_ids, input_ids.size(1) - 1)
|
ref_per_token_logps = compute_per_token_logps(ref_model, input_ids, input_ids.size(1) - 1, attention_mask=full_mask)
|
||||||
|
|
||||||
completion_mask = full_response_masks[:, 1:]
|
completion_mask = full_response_masks[:, 1:]
|
||||||
is_eos = (input_ids[:, 1:] == tokenizer.eos_token_id) & completion_mask.bool()
|
is_eos = (input_ids[:, 1:] == tokenizer.eos_token_id) & completion_mask.bool()
|
||||||
@@ -332,7 +333,6 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
|
|||||||
if step % args.accumulation_steps == 0:
|
if step % args.accumulation_steps == 0:
|
||||||
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
optimizer.step(); scheduler.step(); optimizer.zero_grad()
|
optimizer.step(); scheduler.step(); optimizer.zero_grad()
|
||||||
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model)
|
|
||||||
|
|
||||||
if step % args.log_interval == 0 or step == iters:
|
if step % args.log_interval == 0 or step == iters:
|
||||||
pl = loss.item() * args.accumulation_steps
|
pl = loss.item() * args.accumulation_steps
|
||||||
@@ -359,13 +359,14 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
|
|||||||
model.train()
|
model.train()
|
||||||
del state_dict
|
del state_dict
|
||||||
|
|
||||||
|
if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(model)
|
||||||
|
|
||||||
del per_token_logps, ref_per_token_logps
|
del per_token_logps, ref_per_token_logps
|
||||||
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
|
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
|
||||||
|
|
||||||
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||||
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
optimizer.step(); scheduler.step(); optimizer.zero_grad()
|
optimizer.step(); scheduler.step(); optimizer.zero_grad()
|
||||||
if is_main_process() and last_step % args.save_interval == 0: rollout_engine.update_policy(model)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -470,7 +471,7 @@ if __name__ == "__main__":
|
|||||||
rollout_engine.update_policy(model)
|
rollout_engine.update_policy(model)
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||||
if is_main_process(): rollout_engine.update_policy(model)
|
rollout_engine.update_policy(model)
|
||||||
|
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
train_sampler and train_sampler.set_epoch(epoch)
|
train_sampler and train_sampler.set_epoch(epoch)
|
||||||
|
|||||||
+13
-16
@@ -87,20 +87,17 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
|
|||||||
completion_ids = rollout_result.completion_ids
|
completion_ids = rollout_result.completion_ids
|
||||||
completions = rollout_result.completions
|
completions = rollout_result.completions
|
||||||
old_per_token_logps = rollout_result.per_token_logps.to(args.device)
|
old_per_token_logps = rollout_result.per_token_logps.to(args.device)
|
||||||
|
full_mask = (outputs != tokenizer.pad_token_id).long()
|
||||||
|
|
||||||
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||||
with autocast_ctx:
|
with autocast_ctx:
|
||||||
if use_sglang or lm_config.use_moe:
|
res = model_unwrapped(outputs, attention_mask=full_mask)
|
||||||
res = model_unwrapped(outputs)
|
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
|
||||||
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
|
logits = res.logits[:, :-1, :]
|
||||||
logits = res.logits[:, :-1, :]
|
per_token_logps = F.log_softmax(logits, dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1)[:, -completion_ids.size(1):]
|
||||||
per_token_logps = F.log_softmax(logits, dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1)[:, -completion_ids.size(1):]
|
|
||||||
else:
|
|
||||||
aux_loss = torch.tensor(0.0, device=args.device)
|
|
||||||
per_token_logps = rollout_result.per_token_logps
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
ref_per_token_logps = compute_per_token_logps(ref_model, outputs, completion_ids.size(1))
|
ref_per_token_logps = compute_per_token_logps(ref_model, outputs, completion_ids.size(1), attention_mask=full_mask)
|
||||||
rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen]
|
rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen]
|
||||||
|
|
||||||
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
|
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
|
||||||
@@ -120,7 +117,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
|
|||||||
|
|
||||||
grouped_rewards = rewards.view(-1, args.num_generations) # [B, num_gen]
|
grouped_rewards = rewards.view(-1, args.num_generations) # [B, num_gen]
|
||||||
mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
|
mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
|
||||||
std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
|
std_r = grouped_rewards.std(dim=1, unbiased=False).repeat_interleave(args.num_generations) # [B*num_gen]
|
||||||
advantages = (rewards - mean_r) / (std_r + 1e-4) # [B*num_gen]
|
advantages = (rewards - mean_r) / (std_r + 1e-4) # [B*num_gen]
|
||||||
|
|
||||||
is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R]
|
is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R]
|
||||||
@@ -149,7 +146,6 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
|
|||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model)
|
|
||||||
|
|
||||||
if step % args.log_interval == 0 or step == iters:
|
if step % args.log_interval == 0 or step == iters:
|
||||||
policy_loss_val = loss.item() * args.accumulation_steps
|
policy_loss_val = loss.item() * args.accumulation_steps
|
||||||
@@ -190,16 +186,17 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
|
|||||||
model.train()
|
model.train()
|
||||||
del state_dict
|
del state_dict
|
||||||
|
|
||||||
|
if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(model)
|
||||||
|
|
||||||
|
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
|
||||||
|
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
|
||||||
|
|
||||||
if step > start_step and step % args.accumulation_steps != 0:
|
if step > start_step and step % args.accumulation_steps != 0:
|
||||||
if args.grad_clip > 0:
|
if args.grad_clip > 0:
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model)
|
|
||||||
|
|
||||||
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
|
|
||||||
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -312,7 +309,7 @@ if __name__ == "__main__":
|
|||||||
rollout_engine.update_policy(model)
|
rollout_engine.update_policy(model)
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||||
if is_main_process(): rollout_engine.update_policy(model)
|
rollout_engine.update_policy(model)
|
||||||
|
|
||||||
# ========== 8. 开始训练 ==========
|
# ========== 8. 开始训练 ==========
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
|||||||
@@ -115,7 +115,6 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
|||||||
labels = gen_out[:, 1:].clone() # [B, P+R-1]
|
labels = gen_out[:, 1:].clone() # [B, P+R-1]
|
||||||
seq_len, resp_start = gen_out.size(1) - 1, prompt_length - 1
|
seq_len, resp_start = gen_out.size(1) - 1, prompt_length - 1
|
||||||
resp_mask = torch.arange(seq_len, device=gen_out.device).unsqueeze(0) >= resp_start
|
resp_mask = torch.arange(seq_len, device=gen_out.device).unsqueeze(0) >= resp_start
|
||||||
final_mask = (resp_mask & (~labels.eq(tokenizer.pad_token_id))).float() # [B, P+R-1]
|
|
||||||
B = len(prompts)
|
B = len(prompts)
|
||||||
resp_labels = labels[:, resp_start:] # [B, R]
|
resp_labels = labels[:, resp_start:] # [B, R]
|
||||||
resp_idx = torch.arange(resp_labels.size(1), device=gen_out.device).unsqueeze(0)
|
resp_idx = torch.arange(resp_labels.size(1), device=gen_out.device).unsqueeze(0)
|
||||||
@@ -250,7 +249,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
|||||||
actor_optimizer.zero_grad()
|
actor_optimizer.zero_grad()
|
||||||
critic_optimizer.zero_grad()
|
critic_optimizer.zero_grad()
|
||||||
|
|
||||||
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(actor_model)
|
if step % args.save_interval == 0: rollout_engine.update_policy(actor_model)
|
||||||
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
critic_loss_val = value_loss_sum / max(log_count, 1)
|
critic_loss_val = value_loss_sum / max(log_count, 1)
|
||||||
@@ -296,7 +295,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
|||||||
del actor_state
|
del actor_state
|
||||||
|
|
||||||
del enc, gen_out, responses_text, rewards, full_mask, values_seq, advantages
|
del enc, gen_out, responses_text, rewards, full_mask, values_seq, advantages
|
||||||
del logits, labels, final_mask, resp_labels, resp_idx, resp_pad_mask, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_logp_all, ref_resp_logp
|
del logits, labels, resp_labels, resp_idx, resp_pad_mask, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_logp_all, ref_resp_logp
|
||||||
del kl, kl_ref, policy_loss, value_loss, loss, token_rewards, returns, old_resp_values
|
del kl, kl_ref, policy_loss, value_loss, loss, token_rewards, returns, old_resp_values
|
||||||
|
|
||||||
|
|
||||||
@@ -423,7 +422,7 @@ if __name__ == "__main__":
|
|||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
|
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
|
||||||
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
|
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
|
||||||
if is_main_process(): rollout_engine.update_policy(actor_model)
|
rollout_engine.update_policy(actor_model)
|
||||||
|
|
||||||
# ========== 8. 开始训练 ==========
|
# ========== 8. 开始训练 ==========
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
|||||||
Reference in New Issue
Block a user