import os
import sys

__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import argparse
import re
import gc
import warnings
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model

warnings.filterwarnings('ignore')


def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
    """Combine all reward functions into a single total reward."""

    def reasoning_model_reward(rewards):
        # Format reward: +0.5 when the response follows the <think>...</think><answer>...</answer> layout
        pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
        pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
        matches_pattern = [re.match(pattern, response, re.S) for response in responses]
        matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
        format_rewards = []
        for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
            if match_pattern or match_pattern2:
                format_rewards.append(0.5)
            else:
                format_rewards.append(0.0)
        rewards += torch.tensor(format_rewards, device=args.device)

        # Tag reward: +0.25 for each tag that appears exactly once
        def mark_num(text):
            reward = 0
            if text.count("<think>") == 1:
                reward += 0.25
            if text.count("</think>") == 1:
                reward += 0.25
            if text.count("<answer>") == 1:
                reward += 0.25
            if text.count("</answer>") == 1:
                reward += 0.25
            return reward

        mark_rewards = [mark_num(response) for response in responses]
        rewards += torch.tensor(mark_rewards, device=args.device)
        return rewards

    rewards = torch.zeros(len(responses), device=args.device)
    if args.reasoning == 1:
        rewards = reasoning_model_reward(rewards)

    with torch.no_grad():
        reward_model_scores = []
        batch_size = len(prompts)
        scale = 3.0
        for i in range(batch_size):
            for j in range(args.num_generations):
                response_idx = i * args.num_generations + j
                response = responses[response_idx]
                prompt = prompts[i]
                # Recover the chat messages from the rendered prompt
                pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
                matches = re.findall(pattern, prompt, re.DOTALL)
                messages = [{"role": role, "content": content.strip()} for role, content in matches]
                tmp_chat = messages + [{"role": "assistant", "content": response}]
                # Score the full chat with the external reward model, clipped to [-scale, scale]
                score = reward_model.get_score(reward_tokenizer, tmp_chat)
                score = max(min(score, scale), -scale)
                if args.reasoning == 1:
                    # For reasoning models, also score the <answer> block alone and blend the two scores
                    answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
                    if answer_match:
                        answer_content = answer_match.group(1).strip()
                        tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
                        answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
                        answer_score = max(min(answer_score, scale), -scale)
                        score = score * 0.4 + answer_score * 0.6
                reward_model_scores.append(score)
        reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
        rewards += reward_model_scores

    return rewards

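
# The training step below implements a GRPO-style objective. Informally, for every token t of a
# sampled completion with group-relative advantage A (computed from `num_generations` completions
# of the same prompt):
#
#     ratio_t = exp(logp_t - logp_t.detach())               # 1 in the forward pass; carries the gradient
#     kl_t    = exp(ref_t - logp_t) - (ref_t - logp_t) - 1  # non-negative per-token KL estimate vs. the frozen reference
#     loss_t  = -(ratio_t * A - beta * kl_t)
#
# so the gradient raises the log-probability of tokens from above-average completions, lowers it for
# below-average ones, and the KL term keeps the policy close to the reference model.
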
def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
    for step, batch in enumerate(loader, start=start_step + 1):
        prompts = batch['prompt']  # list[str], length B
        prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
                                  padding_side="left", add_special_tokens=False).to(args.device)  # input_ids: [B, P], attention_mask: [B, P]
        if args.max_seq_len:
            # Keep only the last max_seq_len prompt tokens (prompts are left-padded)
            prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
            prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]

        with torch.no_grad():
            # A DDP-wrapped model exposes generate() via .module
            model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
            outputs = model_for_gen.generate(**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True,
                                             temperature=0.8, num_return_sequences=args.num_generations,
                                             pad_token_id=tokenizer.pad_token_id)  # [B*num_gen, P+R]
            completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]  # [B*num_gen, R]

        def get_per_token_logps(mdl, input_ids, n_keep):
            # Per-token log-probabilities of the last n_keep tokens under mdl
            input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
            logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
            per_token_logps = []
            for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
                ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
                per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
            return torch.stack(per_token_logps)

        with autocast_ctx:
            per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))  # [B*num_gen, R]
            res = model(outputs) if lm_config.use_moe else None
            aux_loss = res.aux_loss if res is not None else torch.tensor(0.0, device=args.device)

        with torch.no_grad():
            ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))  # [B*num_gen, R]

        completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
        rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)  # [B*num_gen]
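        # Group-relative advantages: every prompt contributes `num_generations` completions, and each
        # completion is judged against its own group. Illustrative numbers (not from a real run): with
        # num_generations=4 and group rewards [1.0, 0.0, 0.5, 0.5], mean=0.5 and std≈0.408 (Bessel-corrected,
        # torch's default), so the per-group advantages are roughly [1.22, -1.22, 0.0, 0.0] before the
        # batch-wide re-normalization below; the 1e-4 term guards against all-equal groups with zero std.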
        grouped_rewards = rewards.view(-1, args.num_generations)  # [B, num_gen]
        mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations)  # [B*num_gen]
        std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations)  # [B*num_gen]
        advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)  # [B*num_gen]

        # Mask keeps tokens up to and including the first EOS (all tokens if no EOS was generated)
        is_eos = completion_ids == tokenizer.eos_token_id  # [B*num_gen, R]
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()  # [B*num_gen, R]

        kl_div = ref_per_token_logps - per_token_logps
        per_token_kl = torch.exp(kl_div) - kl_div - 1  # [B*num_gen, R]
        # exp(x - x.detach()) equals 1 in the forward pass but routes the gradient through per_token_logps
        per_token_loss = -(torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - args.beta * per_token_kl)  # [B*num_gen, R]
        policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        loss = (policy_loss + aux_loss) / args.accumulation_steps  # scalar

        loss.backward()

        if (step + 1) % args.accumulation_steps == 0:
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % args.log_interval == 0 or step == iters:
            policy_loss_val = loss.item() * args.accumulation_steps
            current_aux_loss = aux_loss.item()
            avg_reward_val = rewards.mean().item()
            avg_len_val = completion_mask.sum(dim=1).float().mean().item()
            current_lr = optimizer.param_groups[0]['lr']
            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), '
                   f'Actor Loss: {policy_loss_val:.4f}, Aux Loss: {current_aux_loss:.4f}, Reward: {avg_reward_val:.4f}, '
                   f'Avg Response Len: {avg_len_val:.2f}, Learning Rate: {current_lr:.8f}')

            if wandb and is_main_process():
                wandb.log({
                    "policy_loss": policy_loss_val,
                    "aux_loss": current_aux_loss,
                    "reward": avg_reward_val,
                    "avg_response_len": avg_len_val,
                    "advantages_mean": advantages.mean().item(),
                    "learning_rate": current_lr
                })

        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
            moe_suffix = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)  # half-precision inference snapshot
            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, epoch=epoch, step=step,
                          wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
            model.train()
            del state_dict

        del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
        del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask

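
# Typical invocations (illustrative; assumes this script is saved as trainer/train_grpo.py and the
# default model, data, and reward-model paths exist -- adjust to your layout):
#   single GPU: python train_grpo.py --batch_size 2 --num_generations 8
#   multi GPU : torchrun --nproc_per_node 2 train_grpo.py --batch_size 2 --num_generations 8
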
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind GRPO (Group Relative Policy Optimization)")
    parser.add_argument("--save_dir", type=str, default="../out", help="Directory for saved weights")
    parser.add_argument('--save_weight', default='grpo', type=str, help="Prefix for saved weight files")
    parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("--learning_rate", type=float, default=8e-8, help="Initial learning rate")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Training device")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="Mixed-precision dtype")
    parser.add_argument("--num_workers", type=int, default=8, help="Number of dataloader workers")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
    parser.add_argument("--log_interval", type=int, default=1, help="Logging interval (steps)")
    parser.add_argument("--save_interval", type=int, default=10, help="Checkpoint save interval (steps)")
    parser.add_argument('--hidden_size', default=512, type=int, help="Hidden size")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="Number of hidden layers")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="Use MoE architecture (0=no, 1=yes)")
    parser.add_argument('--max_seq_len', default=66, type=int, help="Maximum prompt length")
    parser.add_argument("--max_gen_len", type=int, default=1536, help="Maximum number of generated tokens")
    parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="Path to the RLAIF data")
    parser.add_argument("--num_generations", type=int, default=8, help="Number of completions sampled per prompt")
    parser.add_argument("--beta", type=float, default=0.02, help="KL penalty coefficient")
    parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='Model type (0=standard model, 1=reasoning model)')
    parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Path to the reward model")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="Auto-detect checkpoint and resume training (0=no, 1=yes)")
    parser.add_argument("--use_wandb", action="store_true", help="Enable wandb logging")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO", help="wandb project name")
    args = parser.parse_args()

    # ========== 1. Initialize environment and random seed ==========
    local_rank = init_distributed_mode()
    if dist.is_initialized():
        args.device = f"cuda:{local_rank}"
    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))

    # ========== 2. Set up directories and model config, check for a checkpoint ==========
    os.makedirs(args.save_dir, exist_ok=True)
    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
                               max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
    ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None

    # ========== 3. Set up mixed precision ==========
    device_type = "cuda" if "cuda" in args.device else "cpu"
    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
    autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)

    # ========== 4. Configure wandb ==========
    wandb = None
    if args.use_wandb and is_main_process():
        import swanlab as wandb
        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
        resume = 'must' if wandb_id else None
        wandb_run_name = f"MiniMind-GRPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)

    # ========== 5. Initialize models and data ==========
    base_weight = "reason" if args.reasoning == 1 else "full_sft"
    # Policy model
    model, tokenizer = init_model(lm_config, base_weight, device=args.device)
    # Reference model (frozen copy of the policy's starting weights)
    ref_model, _ = init_model(lm_config, base_weight, device=args.device)
    ref_model = ref_model.eval().requires_grad_(False)
    # Reward model
    reward_model = AutoModel.from_pretrained(
        args.reward_model_path,
        torch_dtype=torch.float16,
        trust_remote_code=True
    )
    reward_model = reward_model.to(args.device).eval().requires_grad_(False)
    reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
    # Data and optimizer
    train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
    iters = len(loader_for_count)
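    # Cosine LR schedule, stepped once per optimizer update. For illustration: with 1000 batches per
    # rank, accumulation_steps=1 and epochs=1 this gives T_max=1000, and the learning rate follows
    # half a cosine from the initial 8e-8 down to eta_min=8e-9.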
    total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
    scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)

    # ========== 6. Restore state from checkpoint ==========
    start_epoch, start_step = 0, 0
    if ckp_data:
        model.load_state_dict(ckp_data['model'])
        optimizer.load_state_dict(ckp_data['optimizer'])
        scheduler.load_state_dict(ckp_data['scheduler'])
        start_epoch = ckp_data['epoch']
        start_step = ckp_data.get('step', 0)

    # ========== 7. Wrap the model with DDP ==========
    if dist.is_initialized():
        # Rotary-embedding buffers are identical on every rank and excluded from DDP synchronization
        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
        model = DistributedDataParallel(model, device_ids=[local_rank])

    # ========== 8. Start training ==========
    for epoch in range(start_epoch, args.epochs):
        train_sampler and train_sampler.set_epoch(epoch)
        if epoch == start_epoch and start_step > 0:  # first epoch after resuming from a checkpoint
            batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
            loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
            grpo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, start_step, wandb)
        else:  # default: start the epoch from the beginning
            loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True, drop_last=False,
                                shuffle=(train_sampler is None), num_workers=args.num_workers, sampler=train_sampler)
            grpo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)
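    # Note: every `save_interval` steps the main process exports a half-precision inference snapshot to
    # {save_dir}/{save_weight}_{hidden_size}[_moe].pth and a resumable training checkpoint (model,
    # optimizer, scheduler, epoch, step) under ../checkpoints; rerun with --from_resume 1 to continue
    # from the latest one.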