mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
442 lines
26 KiB
Python
442 lines
26 KiB
Python
import os
|
||
import sys
|
||
|
||
__package__ = "trainer"
|
||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||
|
||
import argparse
|
||
import math
|
||
import re
|
||
import warnings
|
||
import torch
|
||
import torch.distributed as dist
|
||
import torch.nn.functional as F
|
||
from transformers import AutoTokenizer
|
||
from contextlib import nullcontext
|
||
from torch import optim, nn
|
||
from torch.nn.parallel import DistributedDataParallel
|
||
from torch.utils.data import DataLoader, DistributedSampler
|
||
from torch.nn.utils import clip_grad_norm_
|
||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||
from dataset.lm_dataset import RLAIFDataset
|
||
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model, LMForRewardModel
|
||
from trainer.rollout_engine import create_rollout_engine
|
||
|
||
warnings.filterwarnings('ignore')
|
||
|
||
|
||
def rep_penalty(text, n=3, cap=0.5):
|
||
toks = re.findall(r"\w+|[^\w\s]", text.lower())
|
||
grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)]
|
||
return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0
|
||
|
||
|
||
# 自定义的Critic模型,继承自MiniMindLM
|
||
class CriticModel(MiniMindForCausalLM):
|
||
def __init__(self, params):
|
||
super().__init__(params)
|
||
# 替换lm_head为输出单一价值的线性层
|
||
self.value_head = nn.Linear(params.hidden_size, 1)
|
||
|
||
def forward(self, input_ids=None, attention_mask=None, **kwargs):
|
||
# 使用基础模型获取隐藏状态
|
||
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
|
||
hidden_states = self.model.norm(outputs[0])
|
||
# 使用value_head获取价值估计
|
||
values = self.value_head(hidden_states).squeeze(-1)
|
||
return values
|
||
|
||
|
||
def calculate_rewards(prompts, responses, reward_model):
|
||
rewards = torch.zeros(len(responses), device=args.device)
|
||
|
||
with torch.no_grad():
|
||
reward_model_scores = []
|
||
for i, (prompt, response) in enumerate(zip(prompts, responses)):
|
||
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
|
||
matches = re.findall(pattern, prompt, re.DOTALL)
|
||
messages = [{"role": role, "content": content.strip()} for role, content in matches]
|
||
answer = response
|
||
rewards[i] += 0.5 if 20 <= len(response.strip()) <= 800 else -0.5
|
||
if '</think>' in response:
|
||
thinking_content, answer_content = response.split('</think>', 1)
|
||
rewards[i] += 1.0 if 20 <= len(thinking_content.strip()) <= 300 else -0.5
|
||
rewards[i] += 0.25 if response.count('</think>') == 1 else -0.25
|
||
answer = answer_content.strip()
|
||
rewards[i] -= rep_penalty(answer)
|
||
|
||
score = reward_model.get_score(messages, answer)
|
||
reward_model_scores.append(score)
|
||
|
||
reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
|
||
rewards += reward_model_scores
|
||
|
||
return rewards
|
||
|
||
|
||
def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_scheduler, critic_scheduler, reward_model, start_step=0, wandb=None, use_sglang=False):
|
||
actor_model.train()
|
||
critic_model.train()
|
||
grad_accum_step = 0
|
||
|
||
for step, batch in enumerate(loader, start=start_step + 1):
|
||
prompts = batch["prompt"] # list[str], length B
|
||
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_seq_len,
|
||
padding_side="left").to(args.device) # input_ids: [B, P], attention_mask: [B, P]
|
||
prompt_length = enc.input_ids.shape[1]
|
||
|
||
rollout_result = rollout_engine.rollout(
|
||
prompt_ids=enc.input_ids,
|
||
attention_mask=enc.attention_mask,
|
||
num_generations=1,
|
||
max_new_tokens=args.max_gen_len,
|
||
temperature=0.8,
|
||
)
|
||
gen_out = rollout_result.output_ids
|
||
responses_text = rollout_result.completions
|
||
rewards = calculate_rewards(prompts, responses_text, reward_model) # [B]
|
||
|
||
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
|
||
for i in range(len(prompts)):
|
||
Logger(f"[DEBUG] step={step}, sample[{i}]")
|
||
Logger('-'*100)
|
||
Logger(f"{'=' * 30} [DEBUG] sample[{i}] CONTEXT_BEGIN {'=' * 30}")
|
||
Logger(prompts[i])
|
||
Logger(f"{'=' * 31} [DEBUG] sample[{i}] CONTEXT_END {'=' * 31}")
|
||
Logger(f"[DEBUG] prompt_len={prompt_length}, response_len={len(responses_text[i])}")
|
||
Logger(f"{'=' * 28} [DEBUG] sample[{i}] RESPONSE_BEGIN {'=' * 28}")
|
||
Logger(responses_text[i])
|
||
Logger(f"{'=' * 29} [DEBUG] sample[{i}] RESPONSE_END {'=' * 29}")
|
||
Logger(f"[DEBUG] reward={rewards[i].item():.4f}")
|
||
Logger('='*100)
|
||
|
||
full_mask = (gen_out != tokenizer.pad_token_id).long() # [B, P+R]
|
||
labels = gen_out[:, 1:].clone() # [B, P+R-1]
|
||
seq_len, resp_start = gen_out.size(1) - 1, prompt_length - 1
|
||
resp_mask = torch.arange(seq_len, device=gen_out.device).unsqueeze(0) >= resp_start
|
||
final_mask = (resp_mask & (~labels.eq(tokenizer.pad_token_id))).float() # [B, P+R-1]
|
||
B = len(prompts)
|
||
resp_labels = labels[:, resp_start:] # [B, R]
|
||
resp_idx = torch.arange(resp_labels.size(1), device=gen_out.device).unsqueeze(0)
|
||
resp_pad_mask = ~resp_labels.eq(tokenizer.pad_token_id)
|
||
resp_lengths = resp_pad_mask.sum(dim=1); eos_mask = resp_labels.eq(tokenizer.eos_token_id) & resp_pad_mask
|
||
has_eos = eos_mask.any(dim=1); eos_pos = torch.argmax(eos_mask.int(), dim=1)
|
||
resp_lengths = torch.where(has_eos, eos_pos + 1, resp_lengths).long().clamp(min=1)
|
||
resp_policy_mask = ((resp_idx < resp_lengths.unsqueeze(1)) & resp_pad_mask).float()
|
||
resp_value_mask = resp_policy_mask.clone()
|
||
|
||
with torch.no_grad(): # Rollout阶段只需推理获取old_logp和old_values,切断梯度省显存
|
||
critic_for_rollout = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model
|
||
values_seq = critic_for_rollout(input_ids=gen_out, attention_mask=full_mask)
|
||
old_resp_values = values_seq[:, resp_start:-1] * resp_value_mask
|
||
|
||
actor_for_rollout = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
|
||
with autocast_ctx:
|
||
logits = actor_for_rollout(input_ids=gen_out, attention_mask=full_mask).logits
|
||
|
||
old_resp_logp = F.log_softmax(logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1)[:, resp_start:]
|
||
|
||
ref_logp_all = F.log_softmax(ref_model(input_ids=gen_out, attention_mask=full_mask).logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1)
|
||
ref_resp_logp = ref_logp_all[:, resp_start:]
|
||
token_rewards = torch.zeros_like(old_resp_logp)
|
||
last_idx = resp_lengths - 1 # [B]
|
||
token_rewards[torch.arange(B, device=args.device), last_idx] += rewards # 末尾加外部奖励
|
||
|
||
gen_len = old_resp_values.size(1); lastgaelam = torch.zeros(B, device=args.device); advs_rev = []
|
||
for t in reversed(range(gen_len)):
|
||
nv = old_resp_values[:, t + 1] if t < gen_len - 1 else 0.0
|
||
delta = token_rewards[:, t] + args.gamma * nv - old_resp_values[:, t]
|
||
lastgaelam = delta + args.gamma * args.lam * lastgaelam
|
||
advs_rev.append(lastgaelam)
|
||
advantages = torch.stack(advs_rev[::-1], dim=1) # [B, R]
|
||
returns = advantages + old_resp_values # [B, R]
|
||
|
||
adv_mean = (advantages * resp_policy_mask).sum() / resp_policy_mask.sum().clamp(min=1)
|
||
adv_var = ((advantages - adv_mean) ** 2 * resp_policy_mask).sum() / resp_policy_mask.sum().clamp(min=1)
|
||
advantages = (advantages - adv_mean) * torch.rsqrt(adv_var + 1e-8) * resp_policy_mask
|
||
|
||
mb_size = max(1, min(args.mini_batch_size, B))
|
||
stop_ppo = False
|
||
policy_loss_sum = 0.0
|
||
value_loss_sum = 0.0
|
||
kl_sum = 0.0
|
||
kl_ref_sum = 0.0
|
||
clipfrac_sum = 0.0
|
||
aux_loss_sum = 0.0
|
||
log_count = 0
|
||
actor_unwrapped = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
|
||
critic_unwrapped = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model
|
||
for ppo_epoch in range(args.ppo_update_iters):
|
||
if stop_ppo:
|
||
break
|
||
b_inds = torch.randperm(B, device=args.device)
|
||
for i in range(0, B, mb_size):
|
||
inds = b_inds[i:i + mb_size]
|
||
|
||
mb_values_seq = critic_unwrapped(input_ids=gen_out[inds], attention_mask=full_mask[inds])
|
||
mb_resp_values = mb_values_seq[:, resp_start:-1]
|
||
|
||
with autocast_ctx:
|
||
res = actor_unwrapped(input_ids=gen_out[inds], attention_mask=full_mask[inds])
|
||
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
|
||
|
||
mb_logp_all = F.log_softmax(res.logits[:, :-1], dim=-1).gather(2, labels[inds].unsqueeze(-1)).squeeze(-1)
|
||
mb_resp_logp = mb_logp_all[:, resp_start:]
|
||
|
||
log_ratio = mb_resp_logp - old_resp_logp[inds]
|
||
approx_kl = (0.5 * (log_ratio ** 2) * resp_policy_mask[inds]).sum() / resp_policy_mask[inds].sum().clamp(min=1)
|
||
|
||
# 同步各卡的 approx_kl,防止某卡 break 而其它卡继续导致 DDP 死锁
|
||
approx_kl_val = approx_kl.detach().clone()
|
||
if dist.is_initialized():
|
||
dist.all_reduce(approx_kl_val, op=dist.ReduceOp.AVG)
|
||
|
||
if approx_kl_val > args.early_stop_kl:
|
||
stop_ppo = True
|
||
|
||
ratio = torch.exp(log_ratio)
|
||
clipfrac = ((((ratio - 1.0).abs() > args.clip_epsilon).float() * resp_policy_mask[inds]).sum()
|
||
/ resp_policy_mask[inds].sum().clamp(min=1))
|
||
kl_ref_penalty = ((torch.exp(ref_resp_logp[inds] - mb_resp_logp) - (ref_resp_logp[inds] - mb_resp_logp) - 1.0)
|
||
* resp_policy_mask[inds]).sum() / resp_policy_mask[inds].sum().clamp(min=1)
|
||
policy_loss = ((torch.max(-advantages[inds] * ratio,
|
||
-advantages[inds] * torch.clamp(ratio, 1.0 - args.clip_epsilon, 1.0 + args.clip_epsilon))
|
||
* resp_policy_mask[inds]).sum() / resp_policy_mask[inds].sum().clamp(min=1)
|
||
+ args.kl_coef * kl_ref_penalty)
|
||
value_loss = 0.5 * (torch.max((mb_resp_values - returns[inds]) ** 2,
|
||
(torch.clamp(mb_resp_values, old_resp_values[inds] - args.cliprange_value,
|
||
old_resp_values[inds] + args.cliprange_value) - returns[inds]) ** 2)
|
||
* resp_value_mask[inds]).sum() / resp_value_mask[inds].sum().clamp(min=1)
|
||
|
||
kl = approx_kl_val
|
||
kl_ref = kl_ref_penalty.detach()
|
||
|
||
# 早停时必须保证 forward-backward 闭环,故只截断 loss 不中断 DDP 通信
|
||
if stop_ppo:
|
||
loss = (policy_loss + args.vf_coef * value_loss + aux_loss) * 0.0
|
||
else:
|
||
loss = (policy_loss + args.vf_coef * value_loss + aux_loss) / args.accumulation_steps
|
||
|
||
loss.backward()
|
||
|
||
policy_loss_sum += policy_loss.item()
|
||
value_loss_sum += value_loss.item()
|
||
kl_sum += kl.item()
|
||
kl_ref_sum += kl_ref.item()
|
||
clipfrac_sum += clipfrac.item()
|
||
aux_loss_sum += aux_loss.item()
|
||
log_count += 1
|
||
|
||
grad_accum_step += 1
|
||
|
||
if grad_accum_step % args.accumulation_steps == 0:
|
||
clip_grad_norm_(actor_model.parameters(), args.grad_clip)
|
||
clip_grad_norm_(critic_model.parameters(), args.grad_clip)
|
||
actor_optimizer.step()
|
||
critic_optimizer.step()
|
||
actor_scheduler.step()
|
||
critic_scheduler.step()
|
||
actor_optimizer.zero_grad()
|
||
critic_optimizer.zero_grad()
|
||
|
||
if grad_accum_step % args.accumulation_steps != 0:
|
||
clip_grad_norm_(actor_model.parameters(), args.grad_clip)
|
||
clip_grad_norm_(critic_model.parameters(), args.grad_clip)
|
||
actor_optimizer.step()
|
||
critic_optimizer.step()
|
||
actor_scheduler.step()
|
||
critic_scheduler.step()
|
||
actor_optimizer.zero_grad()
|
||
critic_optimizer.zero_grad()
|
||
|
||
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(actor_model)
|
||
|
||
if is_main_process():
|
||
critic_loss_val = value_loss_sum / max(log_count, 1)
|
||
reward_val = rewards.mean().item()
|
||
approx_kl_val = kl_sum / max(log_count, 1)
|
||
kl_ref_val = kl_ref_sum / max(log_count, 1)
|
||
clipfrac_val = clipfrac_sum / max(log_count, 1)
|
||
avg_len_val = resp_lengths.float().mean().item()
|
||
actor_lr, critic_lr = actor_optimizer.param_groups[0]['lr'], critic_optimizer.param_groups[0]['lr']
|
||
|
||
if wandb is not None:
|
||
wandb.log({
|
||
"reward": reward_val,
|
||
"kl_ref": kl_ref_val,
|
||
"approx_kl": approx_kl_val,
|
||
"clipfrac": clipfrac_val,
|
||
"critic_loss": critic_loss_val,
|
||
"avg_response_len": avg_len_val,
|
||
"actor_lr": actor_lr,
|
||
"critic_lr": critic_lr,
|
||
})
|
||
|
||
Logger(f"Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), "
|
||
f"Reward: {reward_val:.4f}, KL_ref: {kl_ref_val:.4f}, Approx KL: {approx_kl_val:.4f}, "
|
||
f"ClipFrac: {clipfrac_val:.4f}, Critic Loss: {critic_loss_val:.4f}, "
|
||
f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.8f}, Critic LR: {critic_lr:.8f}")
|
||
|
||
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||
actor_model.eval()
|
||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||
raw_actor = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
|
||
raw_actor = getattr(raw_actor, '_orig_mod', raw_actor)
|
||
actor_state = raw_actor.state_dict()
|
||
torch.save({k: v.half().cpu() for k, v in actor_state.items()}, ckp)
|
||
|
||
# 使用 lm_checkpoint 保存完整状态(包括 critic)
|
||
lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer,
|
||
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints',
|
||
scheduler=actor_scheduler, critic_model=critic_model,
|
||
critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler)
|
||
actor_model.train()
|
||
del actor_state
|
||
|
||
del enc, gen_out, responses_text, rewards, full_mask, values_seq, advantages
|
||
del logits, labels, final_mask, resp_labels, resp_idx, resp_pad_mask, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_logp_all, ref_resp_logp
|
||
del kl, kl_ref, policy_loss, value_loss, loss, token_rewards, returns, old_resp_values
|
||
|
||
|
||
if __name__ == "__main__":
|
||
parser = argparse.ArgumentParser(description="MiniMind PPO (Proximal Policy Optimization)")
|
||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||
parser.add_argument('--save_weight', default='ppo_actor', type=str, help="保存权重的前缀名")
|
||
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
|
||
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
|
||
parser.add_argument("--learning_rate", type=float, default=3e-7, help="Actor学习率")
|
||
parser.add_argument("--critic_learning_rate", type=float, default=5e-7, help="Critic学习率")
|
||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
|
||
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||
parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
|
||
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
|
||
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
|
||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||
parser.add_argument('--max_seq_len', default=768, type=int, help="Prompt最大长度")
|
||
parser.add_argument("--max_gen_len", type=int, default=1024, help="生成的最大长度")
|
||
parser.add_argument("--data_path", type=str, default="../dataset/rlaif.jsonl", help="RLAIF数据路径")
|
||
parser.add_argument("--clip_epsilon", type=float, default=0.2, help="PPO裁剪参数")
|
||
parser.add_argument("--vf_coef", type=float, default=0.5, help="Value function系数")
|
||
parser.add_argument("--kl_coef", type=float, default=0.02, help="KL散度惩罚系数")
|
||
parser.add_argument("--gamma", type=float, default=1.0, help="GAE折扣因子")
|
||
parser.add_argument("--lam", type=float, default=0.95, help="GAE lambda参数")
|
||
parser.add_argument("--cliprange_value", type=float, default=0.2, help="Value function裁剪范围")
|
||
parser.add_argument("--ppo_update_iters", type=int, default=2, help="同一批rollout重复更新次数")
|
||
parser.add_argument("--early_stop_kl", type=float, default=0.25, help="PPO early stop 的 KL 阈值")
|
||
parser.add_argument("--mini_batch_size", type=int, default=2, help="PPO每次更新的minibatch大小")
|
||
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练")
|
||
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
|
||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||
parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO", help="wandb项目名")
|
||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||
parser.add_argument("--debug_mode", action="store_true", help="是否打印训练调试采样")
|
||
parser.add_argument("--debug_interval", type=int, default=20, help="debug模式下每隔多少step打印一次采样")
|
||
parser.add_argument("--thinking_ratio", type=float, default=0.9, help="按概率开启thinking(0.0~1.0)")
|
||
parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型")
|
||
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL")
|
||
parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径")
|
||
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_ppo", help="SGLang共享存储路径")
|
||
args = parser.parse_args()
|
||
|
||
# ========== 1. 初始化环境和随机种子 ==========
|
||
local_rank = init_distributed_mode()
|
||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||
|
||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||
os.makedirs(args.save_dir, exist_ok=True)
|
||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||
|
||
# ========== 3. 设置混合精度 ==========
|
||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||
|
||
# ========== 4. 配wandb ==========
|
||
wandb = None
|
||
if args.use_wandb and is_main_process():
|
||
import swanlab as wandb
|
||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||
resume = 'must' if wandb_id else None
|
||
wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
|
||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||
|
||
# ========== 5. 初始化模型和数据 ==========
|
||
base_weight = args.from_weight
|
||
# Actor模型
|
||
actor_model, tokenizer = init_model(lm_config, base_weight, device=args.device)
|
||
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
|
||
ref_model = ref_model.eval().requires_grad_(False)
|
||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||
state_dict = torch.load(ckp, map_location=args.device)
|
||
critic_model = CriticModel(lm_config)
|
||
critic_model.load_state_dict(state_dict, strict=False)
|
||
critic_model = critic_model.to(args.device)
|
||
reward_model = LMForRewardModel(args.reward_model_path, device=args.device, dtype=torch.float16)
|
||
# Rollout引擎
|
||
rollout_engine = create_rollout_engine(
|
||
engine_type=args.rollout_engine,
|
||
policy_model=actor_model,
|
||
tokenizer=tokenizer,
|
||
device=args.device,
|
||
autocast_ctx=autocast_ctx,
|
||
sglang_base_url=args.sglang_base_url,
|
||
sglang_model_path=args.sglang_model_path,
|
||
sglang_shared_path=args.sglang_shared_path,
|
||
)
|
||
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len), thinking_ratio=args.thinking_ratio)
|
||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||
actor_optimizer = optim.AdamW(actor_model.parameters(), lr=args.learning_rate)
|
||
critic_optimizer = optim.AdamW(critic_model.parameters(), lr=args.critic_learning_rate)
|
||
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
|
||
iters = len(loader_for_count)
|
||
mb_factor = max(1, math.ceil(args.batch_size / args.mini_batch_size))
|
||
total_optimizer_steps = math.ceil(iters * args.epochs * args.ppo_update_iters * mb_factor / args.accumulation_steps)
|
||
actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||
critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, eta_min=args.critic_learning_rate / 10)
|
||
|
||
start_epoch, start_step = 0, 0
|
||
if ckp_data:
|
||
actor_model.load_state_dict(ckp_data['model'])
|
||
critic_model.load_state_dict(ckp_data['critic_model'])
|
||
actor_optimizer.load_state_dict(ckp_data['optimizer'])
|
||
critic_optimizer.load_state_dict(ckp_data['critic_optimizer'])
|
||
actor_scheduler.load_state_dict(ckp_data['scheduler'])
|
||
critic_scheduler.load_state_dict(ckp_data['critic_scheduler'])
|
||
start_epoch = ckp_data['epoch']
|
||
start_step = ckp_data.get('step', 0)
|
||
|
||
# ========== 7. 编译和分布式包装 ==========
|
||
if args.use_compile == 1:
|
||
actor_model = torch.compile(actor_model)
|
||
Logger('torch.compile enabled')
|
||
rollout_engine.update_policy(actor_model)
|
||
if dist.is_initialized():
|
||
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
|
||
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
|
||
if is_main_process(): rollout_engine.update_policy(actor_model)
|
||
|
||
# ========== 8. 开始训练 ==========
|
||
for epoch in range(start_epoch, args.epochs):
|
||
train_sampler and train_sampler.set_epoch(epoch)
|
||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||
if skip > 0:
|
||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||
ppo_train_epoch(epoch, loader, len(loader) + skip, rollout_engine, ref_model, actor_scheduler, critic_scheduler, reward_model, start_step, wandb, use_sglang = (args.rollout_engine == "sglang"))
|
||
else:
|
||
ppo_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, actor_scheduler, critic_scheduler, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
|
||
|
||
# ========== 9. 清理分布进程 ==========
|
||
if dist.is_initialized(): dist.destroy_process_group() |