mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-05-01 11:48:14 +08:00
[fix] bugs
This commit is contained in:
parent
6361510016
commit
773e451b11
@ -43,6 +43,8 @@ class RolloutResult:
|
||||
completion_ids: Tensor
|
||||
per_token_logps: Tensor
|
||||
completions: List[str]
|
||||
prompt_lens: Tensor
|
||||
completion_mask: Tensor
|
||||
|
||||
|
||||
# ===== Rollout 引擎抽象基类 =====
|
||||
@ -71,12 +73,12 @@ class TorchRolloutEngine(RolloutEngine):
|
||||
ctx = self.autocast_ctx if self.autocast_ctx else nullcontext()
|
||||
with torch.no_grad(), ctx:
|
||||
output_ids = model.generate(
|
||||
input_ids=prompt_ids,
|
||||
attention_mask=attention_mask,
|
||||
input_ids=prompt_ids.repeat_interleave(num_generations, dim=0),
|
||||
attention_mask=attention_mask.repeat_interleave(num_generations, dim=0),
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
temperature=temperature,
|
||||
num_return_sequences=num_generations,
|
||||
num_return_sequences=1,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
) # [B*num_gen, P+R]
|
||||
@ -85,7 +87,9 @@ class TorchRolloutEngine(RolloutEngine):
|
||||
full_mask = (output_ids != self.tokenizer.pad_token_id).long()
|
||||
per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1), attention_mask=full_mask)
|
||||
completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
|
||||
return RolloutResult(output_ids, completion_ids, per_token_logps, completions)
|
||||
return RolloutResult(output_ids, completion_ids, per_token_logps, completions,
|
||||
prompt_ids.new_full((output_ids.size(0),), prompt_len),
|
||||
attention_mask.new_ones(output_ids.size(0), completion_ids.size(1)))
|
||||
|
||||
def update_policy(self, model: torch.nn.Module):
|
||||
self.policy_model = model
|
||||
@ -127,7 +131,6 @@ class SGLangRolloutEngine(RolloutEngine):
|
||||
|
||||
all_output_ids, all_completion_ids, all_logprobs = [], [], []
|
||||
completions = []
|
||||
prompt_len = prompt_ids.size(1)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
meta = result.get("meta_info", {})
|
||||
@ -144,7 +147,7 @@ class SGLangRolloutEngine(RolloutEngine):
|
||||
if len(logprobs) < len(completion_ids):
|
||||
logprobs = [0.0] * (len(completion_ids) - len(logprobs)) + logprobs
|
||||
elif len(logprobs) > len(completion_ids):
|
||||
logprobs = logprobs[-len(completion_ids):]
|
||||
logprobs = logprobs[-len(completion_ids):] if completion_ids else []
|
||||
prompt = all_input_ids[i]
|
||||
full_output = prompt + completion_ids
|
||||
all_output_ids.append(full_output)
|
||||
@ -153,8 +156,8 @@ class SGLangRolloutEngine(RolloutEngine):
|
||||
completions.append(self.tokenizer.decode(completion_ids, skip_special_tokens=True))
|
||||
|
||||
device = prompt_ids.device
|
||||
max_out_len = max(len(ids) for ids in all_output_ids)
|
||||
max_comp_len = max(len(ids) for ids in all_completion_ids)
|
||||
max_comp_len = max(1, max(len(ids) for ids in all_completion_ids))
|
||||
max_out_len = max(len(ids) for ids in all_input_ids) + max_comp_len
|
||||
|
||||
def pad_to_tensor(seqs, max_len, pad_val=0):
|
||||
return torch.tensor([s + [pad_val] * (max_len - len(s)) for s in seqs], device=device)
|
||||
@ -165,25 +168,29 @@ class SGLangRolloutEngine(RolloutEngine):
|
||||
completion_ids=pad_to_tensor(all_completion_ids, max_comp_len, pad_val=pad_id),
|
||||
per_token_logps=pad_to_tensor(all_logprobs, max_comp_len, pad_val=0.0),
|
||||
completions=completions,
|
||||
prompt_lens=torch.tensor([len(ids) for ids in all_input_ids], device=device),
|
||||
completion_mask=torch.tensor([[1] * len(ids) + [0] * (max_comp_len - len(ids)) for ids in all_completion_ids], device=device),
|
||||
)
|
||||
|
||||
def update_policy(self, model: torch.nn.Module):
|
||||
ok = True
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
unwrapped = getattr(unwrapped, '_orig_mod', unwrapped)
|
||||
abs_path = os.path.abspath(self.shared_ckpt_path)
|
||||
state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()}
|
||||
unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False)
|
||||
self.tokenizer.save_pretrained(abs_path)
|
||||
resp = self.http.post(
|
||||
f"{self.base_url}/update_weights_from_disk",
|
||||
json={"model_path": abs_path},
|
||||
timeout=self.timeout
|
||||
)
|
||||
if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights 失败: {resp.status_code}, {resp.text}")
|
||||
ok = resp.status_code == 200
|
||||
if dist.is_initialized(): dist.barrier()
|
||||
try:
|
||||
unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
unwrapped = getattr(unwrapped, '_orig_mod', unwrapped)
|
||||
abs_path = os.path.abspath(self.shared_ckpt_path)
|
||||
state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()}
|
||||
unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False)
|
||||
self.tokenizer.save_pretrained(abs_path)
|
||||
resp = self.http.post(f"{self.base_url}/update_weights_from_disk", json={"model_path": abs_path}, timeout=self.timeout)
|
||||
if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights 失败: {resp.status_code}, {resp.text}")
|
||||
ok = resp.status_code == 200
|
||||
except Exception as e:
|
||||
print(f"[SGLANG WARNING] update_weights 异常: {e}"); ok = False
|
||||
if dist.is_initialized():
|
||||
ok_t = torch.tensor(int(ok), device=next(model.parameters()).device)
|
||||
dist.broadcast(ok_t, src=0); dist.barrier(); ok = bool(ok_t.item())
|
||||
if not ok: raise RuntimeError("SGLang update_policy failed")
|
||||
return ok
|
||||
|
||||
def flush_cache(self) -> bool:
|
||||
|
||||
@ -22,7 +22,7 @@ from transformers import AutoModel
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from dataset.lm_dataset import RLAIFDataset
|
||||
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model, LMForRewardModel
|
||||
from trainer.rollout_engine import create_rollout_engine, compute_per_token_logps
|
||||
from trainer.rollout_engine import create_rollout_engine
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
@ -87,17 +87,18 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
|
||||
completion_ids = rollout_result.completion_ids
|
||||
completions = rollout_result.completions
|
||||
old_per_token_logps = rollout_result.per_token_logps.to(args.device)
|
||||
prompt_lens = rollout_result.prompt_lens.to(args.device)
|
||||
full_mask = (outputs != tokenizer.pad_token_id).long()
|
||||
logp_pos = prompt_lens.unsqueeze(1) - 1 + torch.arange(completion_ids.size(1), device=args.device).unsqueeze(0)
|
||||
|
||||
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
with autocast_ctx:
|
||||
res = model_unwrapped(outputs, attention_mask=full_mask)
|
||||
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
|
||||
logits = res.logits[:, :-1, :]
|
||||
per_token_logps = F.log_softmax(logits, dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1)[:, -completion_ids.size(1):]
|
||||
per_token_logps = F.log_softmax(res.logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_per_token_logps = compute_per_token_logps(ref_model, outputs, completion_ids.size(1), attention_mask=full_mask)
|
||||
ref_per_token_logps = F.log_softmax(ref_model(outputs, attention_mask=full_mask).logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
|
||||
rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen]
|
||||
|
||||
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
|
||||
@ -120,10 +121,11 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
|
||||
std_r = grouped_rewards.std(dim=1, unbiased=False).repeat_interleave(args.num_generations) # [B*num_gen]
|
||||
advantages = (rewards - mean_r) / (std_r + 1e-4) # [B*num_gen]
|
||||
|
||||
is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R]
|
||||
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
|
||||
completion_pad_mask = rollout_result.completion_mask.to(args.device).bool()
|
||||
is_eos = (completion_ids == tokenizer.eos_token_id) & completion_pad_mask # [B*num_gen, R]
|
||||
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1) - 1, dtype=torch.long, device=args.device)
|
||||
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
||||
completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int() # [B*num_gen, R]
|
||||
completion_mask = ((torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)) & completion_pad_mask).int() # [B*num_gen, R]
|
||||
|
||||
kl_div = ref_per_token_logps - per_token_logps
|
||||
per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*num_gen, R]
|
||||
@ -136,7 +138,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
|
||||
per_token_loss1 = ratio * advantages.unsqueeze(1)
|
||||
per_token_loss2 = clipped_ratio * advantages.unsqueeze(1)
|
||||
per_token_loss = -(torch.min(per_token_loss1, per_token_loss2) - args.beta * per_token_kl)
|
||||
policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
||||
policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1).clamp(min=1)).mean()
|
||||
loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
|
||||
loss.backward()
|
||||
|
||||
@ -152,7 +154,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
|
||||
current_aux_loss = aux_loss.item()
|
||||
avg_reward_val = rewards.mean().item()
|
||||
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
|
||||
kl_ref_val = ((ref_per_token_logps - per_token_logps) * completion_mask).sum().item() / completion_mask.sum().item()
|
||||
kl_ref_val = ((ref_per_token_logps - per_token_logps) * completion_mask).sum().item() / max(completion_mask.sum().item(), 1)
|
||||
advantages_mean_val = advantages.mean().item()
|
||||
advantages_std_val = advantages.std().item()
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
@ -189,7 +191,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
|
||||
if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(model)
|
||||
|
||||
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
|
||||
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
|
||||
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask, completion_pad_mask, prompt_lens, logp_pos
|
||||
|
||||
if step > start_step and step % args.accumulation_steps != 0:
|
||||
if args.grad_clip > 0:
|
||||
|
||||
@ -84,7 +84,6 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
prompts = batch["prompt"] # list[str], length B
|
||||
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_seq_len,
|
||||
padding_side="left").to(args.device) # input_ids: [B, P], attention_mask: [B, P]
|
||||
prompt_length = enc.input_ids.shape[1]
|
||||
|
||||
rollout_result = rollout_engine.rollout(
|
||||
prompt_ids=enc.input_ids,
|
||||
@ -94,7 +93,10 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
temperature=0.8,
|
||||
)
|
||||
gen_out = rollout_result.output_ids
|
||||
completion_ids = rollout_result.completion_ids
|
||||
prompt_lens = rollout_result.prompt_lens.to(args.device)
|
||||
responses_text = rollout_result.completions
|
||||
old_resp_logp = rollout_result.per_token_logps.to(args.device)
|
||||
rewards = calculate_rewards(prompts, responses_text, reward_model) # [B]
|
||||
|
||||
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
|
||||
@ -104,7 +106,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
Logger(f"{'=' * 30} [DEBUG] sample[{i}] CONTEXT_BEGIN {'=' * 30}")
|
||||
Logger(prompts[i])
|
||||
Logger(f"{'=' * 31} [DEBUG] sample[{i}] CONTEXT_END {'=' * 31}")
|
||||
Logger(f"[DEBUG] prompt_len={prompt_length}, response_len={len(responses_text[i])}")
|
||||
Logger(f"[DEBUG] prompt_len={prompt_lens[i].item()}, response_len={len(responses_text[i])}")
|
||||
Logger(f"{'=' * 28} [DEBUG] sample[{i}] RESPONSE_BEGIN {'=' * 28}")
|
||||
Logger(responses_text[i])
|
||||
Logger(f"{'=' * 29} [DEBUG] sample[{i}] RESPONSE_END {'=' * 29}")
|
||||
@ -113,13 +115,12 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
|
||||
full_mask = (gen_out != tokenizer.pad_token_id).long() # [B, P+R]
|
||||
labels = gen_out[:, 1:].clone() # [B, P+R-1]
|
||||
seq_len, resp_start = gen_out.size(1) - 1, prompt_length - 1
|
||||
resp_mask = torch.arange(seq_len, device=gen_out.device).unsqueeze(0) >= resp_start
|
||||
B = len(prompts)
|
||||
resp_labels = labels[:, resp_start:] # [B, R]
|
||||
resp_labels = completion_ids
|
||||
resp_idx = torch.arange(resp_labels.size(1), device=gen_out.device).unsqueeze(0)
|
||||
resp_pad_mask = ~resp_labels.eq(tokenizer.pad_token_id)
|
||||
resp_lengths = resp_pad_mask.sum(dim=1); eos_mask = resp_labels.eq(tokenizer.eos_token_id) & resp_pad_mask
|
||||
logp_pos = prompt_lens.unsqueeze(1) - 1 + resp_idx
|
||||
resp_pad_mask = rollout_result.completion_mask.to(args.device).bool()
|
||||
resp_lengths = resp_pad_mask.sum(dim=1); valid_resp = resp_lengths > 0; eos_mask = resp_labels.eq(tokenizer.eos_token_id) & resp_pad_mask
|
||||
has_eos = eos_mask.any(dim=1); eos_pos = torch.argmax(eos_mask.int(), dim=1)
|
||||
resp_lengths = torch.where(has_eos, eos_pos + 1, resp_lengths).long().clamp(min=1)
|
||||
resp_policy_mask = ((resp_idx < resp_lengths.unsqueeze(1)) & resp_pad_mask).float()
|
||||
@ -128,19 +129,12 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
with torch.no_grad(): # Rollout阶段只需推理获取old_logp和old_values,切断梯度省显存
|
||||
critic_for_rollout = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model
|
||||
values_seq = critic_for_rollout(input_ids=gen_out, attention_mask=full_mask)
|
||||
old_resp_values = values_seq[:, resp_start:-1] * resp_value_mask
|
||||
old_resp_values = values_seq.gather(1, logp_pos) * resp_value_mask
|
||||
|
||||
actor_for_rollout = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
|
||||
with autocast_ctx:
|
||||
logits = actor_for_rollout(input_ids=gen_out, attention_mask=full_mask).logits
|
||||
|
||||
old_resp_logp = F.log_softmax(logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1)[:, resp_start:]
|
||||
|
||||
ref_logp_all = F.log_softmax(ref_model(input_ids=gen_out, attention_mask=full_mask).logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1)
|
||||
ref_resp_logp = ref_logp_all[:, resp_start:]
|
||||
ref_resp_logp = F.log_softmax(ref_model(input_ids=gen_out, attention_mask=full_mask).logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
|
||||
token_rewards = torch.zeros_like(old_resp_logp)
|
||||
last_idx = resp_lengths - 1 # [B]
|
||||
token_rewards[torch.arange(B, device=args.device), last_idx] += rewards # 末尾加外部奖励
|
||||
token_rewards[torch.arange(B, device=args.device)[valid_resp], last_idx[valid_resp]] += rewards[valid_resp] # 末尾加外部奖励
|
||||
|
||||
gen_len = old_resp_values.size(1); lastgaelam = torch.zeros(B, device=args.device); advs_rev = []
|
||||
for t in reversed(range(gen_len)):
|
||||
@ -174,14 +168,13 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
inds = b_inds[i:i + mb_size]
|
||||
|
||||
mb_values_seq = critic_unwrapped(input_ids=gen_out[inds], attention_mask=full_mask[inds])
|
||||
mb_resp_values = mb_values_seq[:, resp_start:-1]
|
||||
mb_resp_values = mb_values_seq.gather(1, logp_pos[inds])
|
||||
|
||||
with autocast_ctx:
|
||||
res = actor_unwrapped(input_ids=gen_out[inds], attention_mask=full_mask[inds])
|
||||
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
|
||||
|
||||
mb_logp_all = F.log_softmax(res.logits[:, :-1], dim=-1).gather(2, labels[inds].unsqueeze(-1)).squeeze(-1)
|
||||
mb_resp_logp = mb_logp_all[:, resp_start:]
|
||||
mb_resp_logp = F.log_softmax(res.logits[:, :-1], dim=-1).gather(2, labels[inds].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos[inds])
|
||||
|
||||
log_ratio = mb_resp_logp - old_resp_logp[inds]
|
||||
approx_kl = (0.5 * (log_ratio ** 2) * resp_policy_mask[inds]).sum() / resp_policy_mask[inds].sum().clamp(min=1)
|
||||
@ -249,7 +242,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
actor_optimizer.zero_grad()
|
||||
critic_optimizer.zero_grad()
|
||||
|
||||
if step % args.save_interval == 0: rollout_engine.update_policy(actor_model)
|
||||
if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(actor_model)
|
||||
|
||||
if is_main_process():
|
||||
critic_loss_val = value_loss_sum / max(log_count, 1)
|
||||
@ -294,9 +287,9 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
actor_model.train()
|
||||
del actor_state
|
||||
|
||||
del enc, gen_out, responses_text, rewards, full_mask, values_seq, advantages
|
||||
del logits, labels, resp_labels, resp_idx, resp_pad_mask, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_logp_all, ref_resp_logp
|
||||
del kl, kl_ref, policy_loss, value_loss, loss, token_rewards, returns, old_resp_values
|
||||
del enc, gen_out, completion_ids, responses_text, rewards, full_mask, values_seq, advantages
|
||||
del labels, resp_labels, resp_idx, resp_pad_mask, valid_resp, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_resp_logp
|
||||
del kl, kl_ref, policy_loss, value_loss, loss, token_rewards, returns, old_resp_values, prompt_lens, logp_pos
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Loading…
Reference in New Issue
Block a user