# Mirror of https://github.com/jingyaogong/minimind.git (synced 2026-01-13 19:57:20 +08:00)
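"""GRPO (Group Relative Policy Optimization) training script for MiniMind.

For each prompt the policy samples several completions, scores them with
rule-based format rewards plus an external reward model, and is updated
against a frozen reference model with a per-token KL penalty.
"""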
import os
import sys

__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import re
import gc
import torch
from contextlib import nullcontext
import torch.distributed as dist
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from torch.optim.lr_scheduler import CosineAnnealingLR

def Logger(content):
    if not ddp or dist.get_rank() == 0:
        print(content)

def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
    """Combine all reward components into one total reward per response."""

    def reasoning_model_reward(rewards):
        # 1. Format reward (only used when training the reasoning model)
        pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
        # pattern2 additionally tolerates a blank line between </think> and <answer>
        pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"

        matches_pattern = [re.match(pattern, response, re.S) for response in responses]
        matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]

        format_rewards = []
        for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
            if match_pattern:
                format_rewards.append(0.5)
            elif match_pattern2:
                format_rewards.append(0.5)
            else:
                format_rewards.append(0.0)
        rewards += torch.tensor(format_rewards, device=args.device)

        # 2. Tag-count reward (keeps the strict format reward from being too sparse;
        #    only used when training the reasoning model)
        def mark_num(text):
            reward = 0
            if text.count("<think>") == 1:
                reward += 0.25
            if text.count("</think>") == 1:
                reward += 0.25
            if text.count("<answer>") == 1:
                reward += 0.25
            if text.count("</answer>") == 1:
                reward += 0.25
            return reward

        mark_rewards = [mark_num(response) for response in responses]
        rewards += torch.tensor(mark_rewards, device=args.device)
        return rewards

    rewards = torch.zeros(len(responses), device=args.device)

    # 3. Format rewards
    if args.reasoning == 1:
        rewards = reasoning_model_reward(rewards)  # only when training the reasoning model

    # 4. Reward-model score
    with torch.no_grad():
        reward_model_scores = []
        batch_size = len(prompts)
        scale = 3.0  # clamp reward-model scores to [-scale, scale]

        for i in range(batch_size):
            for j in range(args.num_generations):
                response_idx = i * args.num_generations + j
                response = responses[response_idx]
                prompt = prompts[i]

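                # The prompt is a rendered ChatML-style string; parse it back into a
                # message list so the reward model can score the whole conversation.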
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
|
||
matches = re.findall(pattern, prompt, re.DOTALL)
|
||
messages = [{"role": role, "content": content.strip()} for role, content in matches]
|
||
|
||
tmp_chat = messages + [{"role": "assistant", "content": response}]
|
||
score = reward_model.get_score(reward_tokenizer, tmp_chat)
|
||
score = max(min(score, scale), -scale)
|
||
|
||
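                # For the reasoning model, also score the <answer> block on its own and
                # blend it with the full-response score (0.4 full / 0.6 answer-only), so
                # the final answer weighs more than the thinking trace.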
                if args.reasoning == 1:
                    answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
                    if answer_match:
                        answer_content = answer_match.group(1).strip()
                        tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
                        answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
                        answer_score = max(min(answer_score, scale), -scale)
                        score = score * 0.4 + answer_score * 0.6

                reward_model_scores.append(score)

        reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
        rewards += reward_model_scores

    return rewards

def grpo_train_epoch(epoch, wandb):
    for step, batch in enumerate(train_loader):
        prompts = batch['prompt']  # list[str], length B
        prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
                                  padding_side="left", add_special_tokens=False).to(args.device)  # input_ids: [B, P], attention_mask: [B, P]
        if args.max_seq_len:
            prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
            prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]

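        # Roll out num_generations samples per prompt with temperature sampling; the
        # samples that share a prompt form the group that GRPO normalizes over below.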
        with torch.no_grad():
            outputs = (model.module if ddp else model).generate(
                **prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
                num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id)  # [B*num_gen, P+R]

        completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]  # [B*num_gen, R]

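        # Per-token log-probabilities of the completion tokens only: logits_to_keep
        # limits the returned logits to the trailing n_keep + 1 positions, and the last
        # position is dropped so that the logit at position t scores token t + 1.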
        def get_per_token_logps(mdl, input_ids, n_keep):
            input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
            logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
            per_token_logps = []
            for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
                ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
                per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
            return torch.stack(per_token_logps)

        per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))  # [B*num_gen, R]
        with torch.no_grad():
            ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))  # [B*num_gen, R]

        completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
        rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)  # [B*num_gen]

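        # Group-relative advantages (the core of GRPO): each reward is normalized by the
        # mean/std of its own prompt group, A_i = (r_i - mean(group)) / (std(group) + eps),
        # then standardized across the whole batch for numerical stability.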
        grouped_rewards = rewards.view(-1, args.num_generations)  # [B, num_gen]
        mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations)  # [B*num_gen]
        std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations)  # [B*num_gen]
        advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)  # [B*num_gen]

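        # Mask out everything after the first EOS in each completion; only tokens up to
        # and including that EOS contribute to the loss (no-EOS sequences keep all tokens).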
        is_eos = completion_ids == tokenizer.eos_token_id  # [B*num_gen, R]
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()  # [B*num_gen, R]

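        # KL penalty via the k3 estimator exp(d) - d - 1 with d = log p_ref - log p_theta.
        # The ratio exp(logp - logp.detach()) equals 1 in value but carries the policy
        # gradient, so the loss below is the single-update GRPO objective (no PPO-style
        # clipping), with the KL term weighted by beta.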
        kl_div = ref_per_token_logps - per_token_logps
        per_token_kl = torch.exp(kl_div) - kl_div - 1  # [B*num_gen, R]
        per_token_loss = -(torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - args.beta * per_token_kl)  # [B*num_gen, R]
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() / args.accumulation_steps  # scalar
        loss.backward()

        if (step + 1) % args.accumulation_steps == 0:
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % args.log_interval == 0 or step == iter_per_epoch - 1:
            policy_loss_val = loss.item()
            avg_reward_val = rewards.mean().item()
            avg_len_val = completion_mask.sum(dim=1).float().mean().item()
            current_lr = optimizer.param_groups[0]['lr']

            Logger(
                f'Epoch: {epoch}, Step: {step + 1}/{iter_per_epoch}, '
                f'Actor Loss: {policy_loss_val:.4f}, Reward: {avg_reward_val:.4f}, '
                f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')

            if wandb and (not ddp or dist.get_rank() == 0):
                log_dict = {
                    "policy_loss": policy_loss_val,
                    "reward": avg_reward_val,
                    "avg_response_len": avg_len_val,
                    "advantages_mean": advantages.mean().item(),
                    "learning_rate": current_lr
                }
                wandb.log(log_dict)

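        # Periodic checkpoint (main process only); weights are saved in half precision
        # to keep the file small.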
        if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
            model.eval()
            moe_path = '_moe' if lm_config.use_moe else ''
            suffix = 'grpo'
            ckp = f'{args.save_dir}/{suffix}_{lm_config.hidden_size}{moe_path}.pth'

            state_dict = model.module.state_dict() if isinstance(model,
                                                                 torch.nn.parallel.DistributedDataParallel) else model.state_dict()
            torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
            model.train()

        del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
        del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
        torch.cuda.empty_cache()
        gc.collect()

def init_model(lm_config):
    tokenizer = AutoTokenizer.from_pretrained('../model/')
    model = MiniMindForCausalLM(lm_config)
    moe_path = '_moe' if lm_config.use_moe else ''
    ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
    if args.reasoning == 1:
        ckp = f'{args.save_dir}/reason_{lm_config.hidden_size}{moe_path}.pth'
    state_dict = torch.load(ckp, map_location=args.device)
    model.load_state_dict(state_dict, strict=False)

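    # Frozen reference policy: a second copy of the same checkpoint, used only to
    # compute the per-token KL penalty against the trained policy.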
    ref_model = MiniMindForCausalLM(lm_config)
    ref_model.load_state_dict(state_dict, strict=False)
    ref_model.eval().requires_grad_(False)

    Logger(f'Total trainable LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
    model = model.to(args.device)
    ref_model = ref_model.to(args.device)

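    # External reward model (InternLM2 1.8B reward), loaded with trust_remote_code so
    # that its custom get_score(tokenizer, chat) helper is available for scoring.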
    reward_name = "../../internlm2-1_8b-reward"
    reward_model = AutoModel.from_pretrained(
        reward_name,
        device_map="cuda",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    ).to(args.device).eval().requires_grad_(False)
    reward_tokenizer = AutoTokenizer.from_pretrained(reward_name, trust_remote_code=True)

    return model, ref_model, tokenizer, reward_model, reward_tokenizer

def init_distributed_mode():
    if not ddp: return
    global ddp_local_rank, DEVICE
    dist.init_process_group(backend="nccl")
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_dir", type=str, default="../out")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--learning_rate", type=float, default=8e-8)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO")
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--ddp", action="store_true")
    parser.add_argument("--accumulation_steps", type=int, default=1)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--log_interval", type=int, default=1)
    parser.add_argument("--save_interval", type=int, default=10)
    parser.add_argument('--hidden_size', default=512, type=int)
    parser.add_argument('--num_hidden_layers', default=8, type=int)
    parser.add_argument('--use_moe', default=False, type=bool)
    parser.add_argument('--max_seq_len', default=66, type=int)
    parser.add_argument("--max_gen_len", type=int, default=1536)
    parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl")
    parser.add_argument("--num_generations", type=int, default=8)
    parser.add_argument("--beta", type=float, default=0.02)
    parser.add_argument("--reasoning", type=int, default=1, help='0: standard model, 1: reasoning model')
    args = parser.parse_args()

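    # Size the model context to hold the (left-truncated) prompt plus the longest
    # possible generation: max_seq_len + max_gen_len.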
    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
                               max_seq_len=args.max_seq_len + args.max_gen_len,
                               use_moe=args.use_moe)
    args.save_dir = os.path.join(args.out_dir)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)

    ctx = nullcontext() if "cuda" not in args.device else torch.amp.autocast('cuda')
    ddp = int(os.environ.get("RANK", -1)) != -1
    ddp_local_rank, DEVICE = 0, "cuda:0"

    base_seed = 1337
    torch.manual_seed(base_seed)
    torch.cuda.manual_seed(base_seed)

    if ddp:
        init_distributed_mode()
        args.device = torch.device(DEVICE)
        rank = dist.get_rank()
        torch.manual_seed(base_seed + rank)
        # also seed CUDA with the per-rank offset
        torch.cuda.manual_seed(base_seed + rank)

    if args.use_wandb and (not ddp or ddp_local_rank == 0):
        import swanlab as wandb  # swanlab serves as a drop-in for the wandb calls used below (init/log)

        wandb.init(project=args.wandb_project)
    else:
        wandb = None

    model, ref_model, tokenizer, reward_model, reward_tokenizer = init_model(lm_config=lm_config)
    train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if ddp else None
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
                              drop_last=False, shuffle=False,
                              num_workers=args.num_workers, sampler=train_sampler)

    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

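    # Cosine decay over the number of optimizer steps (accounting for gradient
    # accumulation), annealing from learning_rate down to learning_rate / 10.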
    iter_per_epoch = len(train_loader)
    total_optimizer_steps = (iter_per_epoch // args.accumulation_steps) * args.epochs
    scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)

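    # pos_cis (the precomputed rotary-position buffer) is identical and untrained on
    # every rank, so it is excluded from DDP parameter/buffer synchronization.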
    if ddp:
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])

    for epoch in range(args.epochs):
        grpo_train_epoch(epoch, wandb)