# minimind/trainer/train_grpo.py
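"""GRPO (Group Relative Policy Optimization) trainer for MiniMind.

Samples multiple completions per prompt from the current policy, scores them with
rule-based format rewards plus an external reward model, converts the scores into
group-relative advantages, and optimizes a policy-gradient loss with a per-token
KL penalty against a frozen reference copy of the model.
"""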
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import re
import gc
import torch
from contextlib import nullcontext
import torch.distributed as dist
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from torch.optim.lr_scheduler import CosineAnnealingLR


def Logger(content):
    if not ddp or dist.get_rank() == 0:
        print(content)


def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
    """Combine all reward terms into one total reward per response."""
    def reasoning_model_reward(rewards):
        # 1. Format reward (only used when training the reasoning model)
        pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
        pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
        matches_pattern = [re.match(pattern, response, re.S) for response in responses]
        matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
        format_rewards = []
        for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
            if match_pattern or match_pattern2:
                format_rewards.append(0.5)
            else:
                format_rewards.append(0.0)
        rewards += torch.tensor(format_rewards, device=args.device)

        # 2. Tag-count reward (softens the sparse strict-format reward;
        #    only used when training the reasoning model)
        def mark_num(text):
            reward = 0
            if text.count("<think>") == 1:
                reward += 0.25
            if text.count("</think>") == 1:
                reward += 0.25
            if text.count("<answer>") == 1:
                reward += 0.25
            if text.count("</answer>") == 1:
                reward += 0.25
            return reward

        mark_rewards = [mark_num(response) for response in responses]
        rewards += torch.tensor(mark_rewards, device=args.device)
        return rewards

    rewards = torch.zeros(len(responses), device=args.device)
    # 3. Format rewards (reasoning model only)
    if args.reasoning == 1:
        rewards = reasoning_model_reward(rewards)

    # 4. Score every response with the external reward model
    with torch.no_grad():
        reward_model_scores = []
        batch_size = len(prompts)
        scale = 3.0
        for i in range(batch_size):
            for j in range(args.num_generations):
                response_idx = i * args.num_generations + j
                response = responses[response_idx]
                prompt = prompts[i]
                # Recover the chat messages from the templated prompt string
                pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
                matches = re.findall(pattern, prompt, re.DOTALL)
                messages = [{"role": role, "content": content.strip()} for role, content in matches]
                tmp_chat = messages + [{"role": "assistant", "content": response}]
                score = reward_model.get_score(reward_tokenizer, tmp_chat)
                score = max(min(score, scale), -scale)  # clip to [-scale, scale]
                if args.reasoning == 1:
                    # Also score the <answer> part alone and blend the two scores
                    answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
                    if answer_match:
                        answer_content = answer_match.group(1).strip()
                        tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
                        answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
                        answer_score = max(min(answer_score, scale), -scale)
                        score = score * 0.4 + answer_score * 0.6
                reward_model_scores.append(score)
        reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
        rewards += reward_model_scores
    return rewards
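

# One GRPO step (implemented in grpo_train_epoch below): sample num_generations completions
# per prompt with the current policy, score them via calculate_rewards, normalize rewards
# within each prompt's group into advantages, then minimize a policy-gradient surrogate with
# a per-token KL penalty (weight --beta) against the frozen reference model.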
def grpo_train_epoch(epoch, wandb):
    for step, batch in enumerate(train_loader):
        prompts = batch['prompt']  # list[str], length B
        prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
                                  padding_side="left", add_special_tokens=False).to(args.device)  # input_ids: [B, P], attention_mask: [B, P]
        if args.max_seq_len:
            prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
            prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
        with torch.no_grad():
            outputs = (model.module if ddp else model).generate(
                **prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
                num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id)  # [B*num_gen, P+R]
        completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]  # [B*num_gen, R]
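
        # Per-token log-probabilities of the completion tokens. The model is asked for
        # n_keep + 1 logits and the last position is dropped, so logits[:, t] scores
        # completion token t (the usual one-step shift between logits and labels).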
        def get_per_token_logps(mdl, input_ids, n_keep):
            input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
            logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
            per_token_logps = []
            for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
                ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
                per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
            return torch.stack(per_token_logps)

        per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))  # [B*num_gen, R]
        with torch.no_grad():
            ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))  # [B*num_gen, R]

        completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
        rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)  # [B*num_gen]
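
        # GRPO advantage: each reward is normalized against the mean/std of its own prompt's
        # group of num_generations samples, so completions compete against their siblings
        # rather than against an absolute baseline.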
        grouped_rewards = rewards.view(-1, args.num_generations)  # [B, num_gen]
        mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations)  # [B*num_gen]
        std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations)  # [B*num_gen]
        advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)  # [B*num_gen]

        # Mask out tokens after the first EOS in each completion
        is_eos = completion_ids == tokenizer.eos_token_id  # [B*num_gen, R]
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()  # [B*num_gen, R]
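
        # k3 KL estimator: with x = log(ref/pi) per token, exp(x) - x - 1 >= 0 estimates
        # KL(pi || ref) with low variance. exp(logp - logp.detach()) equals 1 in value but
        # carries the gradient of logp, so the first term is the advantage-weighted policy gradient.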
        kl_div = ref_per_token_logps - per_token_logps
        per_token_kl = torch.exp(kl_div) - kl_div - 1  # [B*num_gen, R]
        per_token_loss = -(torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - args.beta * per_token_kl)  # [B*num_gen, R]
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() / args.accumulation_steps  # scalar
        loss.backward()
        if (step + 1) % args.accumulation_steps == 0:
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % args.log_interval == 0 or step == iter_per_epoch - 1:
            policy_loss_val = loss.item()
            avg_reward_val = rewards.mean().item()
            avg_len_val = completion_mask.sum(dim=1).float().mean().item()
            current_lr = optimizer.param_groups[0]['lr']
            Logger(
                f'Epoch: {epoch}, Step: {step + 1}/{iter_per_epoch}, '
                f'Actor Loss: {policy_loss_val:.4f}, Reward: {avg_reward_val:.4f}, '
                f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
            if wandb and (not ddp or dist.get_rank() == 0):
                log_dict = {
                    "policy_loss": policy_loss_val,
                    "reward": avg_reward_val,
                    "avg_response_len": avg_len_val,
                    "advantages_mean": advantages.mean().item(),
                    "learning_rate": current_lr
                }
                wandb.log(log_dict)

        if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
            model.eval()
            moe_path = '_moe' if lm_config.use_moe else ''
            suffix = 'grpo'
            ckp = f'{args.save_dir}/{suffix}_{lm_config.hidden_size}{moe_path}.pth'
            state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
            # Save weights in half precision to keep the checkpoint small
            torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
            model.train()

        # Free per-step tensors to keep GPU memory bounded
        del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
        del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
        torch.cuda.empty_cache()
        gc.collect()


def init_model(lm_config):
    tokenizer = AutoTokenizer.from_pretrained('../model/')
    model = MiniMindForCausalLM(lm_config)
    moe_path = '_moe' if lm_config.use_moe else ''
    # Start GRPO from the full-SFT checkpoint, or from the reasoning checkpoint when --reasoning=1
    ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
    if args.reasoning == 1:
        ckp = f'{args.save_dir}/reason_{lm_config.hidden_size}{moe_path}.pth'
    state_dict = torch.load(ckp, map_location=args.device)
    model.load_state_dict(state_dict, strict=False)

    # Frozen reference policy for the KL penalty
    ref_model = MiniMindForCausalLM(lm_config)
    ref_model.load_state_dict(state_dict, strict=False)
    ref_model.eval().requires_grad_(False)

    Logger(f'LLM total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
    model = model.to(args.device)
    ref_model = ref_model.to(args.device)

    # External reward model (frozen), loaded with trust_remote_code for its get_score API
    reward_name = "../../internlm2-1_8b-reward"
    reward_model = AutoModel.from_pretrained(
        reward_name,
        device_map="cuda",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    ).to(args.device).eval().requires_grad_(False)
    reward_tokenizer = AutoTokenizer.from_pretrained(reward_name, trust_remote_code=True)
    return model, ref_model, tokenizer, reward_model, reward_tokenizer


def init_distributed_mode():
    if not ddp:
        return
    global ddp_local_rank, DEVICE
    dist.init_process_group(backend="nccl")
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_dir", type=str, default="../out")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--learning_rate", type=float, default=8e-8)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO")
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--ddp", action="store_true")
    parser.add_argument("--accumulation_steps", type=int, default=1)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--log_interval", type=int, default=1)
    parser.add_argument("--save_interval", type=int, default=10)
    parser.add_argument('--hidden_size', default=512, type=int)
    parser.add_argument('--num_hidden_layers', default=8, type=int)
    parser.add_argument('--use_moe', default=False, type=bool)
    parser.add_argument('--max_seq_len', default=66, type=int)
    parser.add_argument("--max_gen_len", type=int, default=1536)
    parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl")
    parser.add_argument("--num_generations", type=int, default=8)
    parser.add_argument("--beta", type=float, default=0.02)
    parser.add_argument("--reasoning", type=int, default=1, help='0: base model, 1: reasoning model')
    args = parser.parse_args()
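
    # Prompts are later left-truncated to max_seq_len, while the model context is sized to
    # max_seq_len + max_gen_len so a full-length generation still fits.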
    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
                               max_seq_len=args.max_seq_len + args.max_gen_len,
                               use_moe=args.use_moe)
    args.save_dir = os.path.join(args.out_dir)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)
    ctx = nullcontext() if "cuda" not in args.device else torch.amp.autocast('cuda')

    ddp = int(os.environ.get("RANK", -1)) != -1
    ddp_local_rank, DEVICE = 0, "cuda:0"
    base_seed = 1337
    torch.manual_seed(base_seed)
    torch.cuda.manual_seed(base_seed)
    if ddp:
        init_distributed_mode()
        args.device = torch.device(DEVICE)
        rank = dist.get_rank()
        # Give each rank its own CPU and CUDA seed so sampling differs across processes
        torch.manual_seed(base_seed + rank)
        torch.cuda.manual_seed(base_seed + rank)

    if args.use_wandb and (not ddp or ddp_local_rank == 0):
        import swanlab as wandb
        wandb.init(project=args.wandb_project)
    else:
        wandb = None

    model, ref_model, tokenizer, reward_model, reward_tokenizer = init_model(lm_config=lm_config)
    train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if ddp else None
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
                              drop_last=False, shuffle=False,
                              num_workers=args.num_workers, sampler=train_sampler)
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    iter_per_epoch = len(train_loader)
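
    # One scheduler step per optimizer update (not per batch), so T_max counts optimizer steps.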
    total_optimizer_steps = (iter_per_epoch // args.accumulation_steps) * args.epochs
    scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)

    if ddp:
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])

    for epoch in range(args.epochs):
        grpo_train_epoch(epoch, wandb)
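
# Typical launches (paths and GPU counts are illustrative):
#   python train_grpo.py --data_path ../dataset/rlaif-mini.jsonl --reasoning 1
#   torchrun --nproc_per_node 2 train_grpo.py --batch_size 2 --num_generations 8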