[add] add CISPO (Clipped Importance Sampling Policy Optimization) algorithm
This is the main addition relative to the reference GRPO algorithm: instead of setting the gradient of out-of-bounds tokens directly to 0, CISPO applies bounded clipping to their importance weights, so high-value exploration tokens keep participating in parameter updates while training stays stable. In practice, this change made little difference in our experiments, because the ratio stays almost exactly at 1, meaning tokens rarely go out of bounds.
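For reference, a minimal sketch of the two per-token objectives being contrasted (simplified: the KL penalty is omitted and the function names are illustrative, not the repository's exact code; ratio = exp(logp - logp_old), adv is the group-normalized advantage):

import torch

def grpo_token_loss(ratio, adv, eps=0.2):
    # PPO/GRPO-style surrogate: whenever the clipped branch is the active one,
    # the out-of-bounds token contributes zero gradient (the clamp blocks it).
    return -torch.min(ratio * adv, torch.clamp(ratio, 1 - eps, 1 + eps) * adv)

def cispo_token_loss(ratio, logp, adv, eps=0.2):
    # CISPO: the clipped ratio is detached into a constant weight, so every
    # token keeps the bounded gradient clip(ratio) * adv * grad(logp), never 0.
    weight = torch.clamp(ratio, 1 - eps, 1 + eps).detach()
    return -weight * logp * adv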
This commit is contained in:
parent 7389f64dee
commit ce9ed24dcd

train_cispo.py (new file, 343 lines added)
@@ -0,0 +1,343 @@
import os
import sys

__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import argparse
import re
import gc
import warnings
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model

warnings.filterwarnings('ignore')

def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
    """Combine all reward functions into a total reward."""
    def reasoning_model_reward(rewards):
        # Format reward: +0.5 if the response matches the <think>/<answer> template
        pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
        pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
        matches_pattern = [re.match(pattern, response, re.S) for response in responses]
        matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]

        format_rewards = []
        for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
            if match_pattern or match_pattern2:
                format_rewards.append(0.5)
            else:
                format_rewards.append(0.0)
        rewards += torch.tensor(format_rewards, device=args.device)

        # Tag-count reward: +0.25 for each tag that appears exactly once
        def mark_num(text):
            reward = 0
            if text.count("<think>") == 1: reward += 0.25
            if text.count("</think>") == 1: reward += 0.25
            if text.count("<answer>") == 1: reward += 0.25
            if text.count("</answer>") == 1: reward += 0.25
            return reward

        mark_rewards = [mark_num(response) for response in responses]
        rewards += torch.tensor(mark_rewards, device=args.device)
        return rewards

    rewards = torch.zeros(len(responses), device=args.device)
    if args.reasoning == 1:
        rewards = reasoning_model_reward(rewards)

    with torch.no_grad():
        reward_model_scores = []
        batch_size = len(prompts)
        scale = 3.0

        for i in range(batch_size):
            for j in range(args.num_generations):
                response_idx = i * args.num_generations + j
                response = responses[response_idx]
                prompt = prompts[i]

                # Recover the chat messages from the rendered prompt
                pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
                matches = re.findall(pattern, prompt, re.DOTALL)
                messages = [{"role": role, "content": content.strip()} for role, content in matches]

                tmp_chat = messages + [{"role": "assistant", "content": response}]
                score = reward_model.get_score(reward_tokenizer, tmp_chat)
                score = max(min(score, scale), -scale)  # clip the RM score to [-scale, scale]

                if args.reasoning == 1:
                    # For reasoning models, blend the full-response score with an answer-only score
                    answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
                    if answer_match:
                        answer_content = answer_match.group(1).strip()
                        tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
                        answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
                        answer_score = max(min(answer_score, scale), -scale)
                        score = score * 0.4 + answer_score * 0.6

                reward_model_scores.append(score)

        reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
        rewards += reward_model_scores

    return rewards
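
# Illustrative reward composition for args.reasoning == 1 (numbers hypothetical):
# a response matching the template exactly once earns format 0.5 plus tag-count
# 4 * 0.25 = 1.5, plus a blended reward-model term 0.4 * score_full +
# 0.6 * score_answer, each clipped to [-3.0, 3.0], so per-sample totals land
# roughly in [-3.0, 4.5].
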
def get_per_token_logps(mdl, input_ids, n_keep):
    """Compute the log-probability of each generated token."""
    # CISPO computes logps several times per batch; gradients must be kept
    # whenever the model is in training mode.
    if not mdl.training:
        with torch.no_grad():
            logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
    else:
        logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]

    per_token_logps = []
    for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
        per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
    return torch.stack(per_token_logps)
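
# Shape note for the function above (B sequences, T = n_keep completion tokens):
# input_ids is (B, prompt_len + T); requesting the last n_keep + 1 logits and
# slicing [:, :-1, :] aligns position t with target token t + 1, so the
# returned tensor is (B, T), one logp per completion token.
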
def cispo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
    for step, batch in enumerate(loader, start=start_step + 1):
        prompts = batch['prompt']
        prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
                                  padding_side="left", add_special_tokens=False).to(args.device)

        if args.max_seq_len:
            prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
            prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]

        # ========== 1. Sampling ==========
        with torch.no_grad():
            model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
            outputs = model_for_gen.generate(
                **prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
                num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id)

        completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]

        # ========== 2. Prepare reference and old policy ==========
        with torch.no_grad():
            # Reference-model logps (used for the KL penalty)
            ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))
            # Current (old) policy logps (used for the ratio; held fixed across the inner loop)
            old_per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))

        # ========== 3. Reward & advantage ==========
        completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
        rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)

        grouped_rewards = rewards.view(-1, args.num_generations)
        mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations)
        std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations)

        # Group-normalized advantages
        advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
        # Global normalization could be applied here instead:
        # advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Build the completion mask (everything up to and including the first EOS)
        is_eos = completion_ids == tokenizer.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()
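
        # Example of the group normalization above: with num_generations = 4 and
        # rewards [1.0, 2.0, 3.0, 4.0] within one group, mean_r = 2.5 and
        # std_r ≈ 1.29 (unbiased), giving advantages ≈ [-1.16, -0.39, 0.39, 1.16]:
        # better-than-average samples get positive advantages, worse ones negative.
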
        # ========== 4. CISPO inner loop (multiple updates per batch) ==========
        # Only across repeated updates does the ratio drift away from 1, so
        # clipping only takes effect when this loop runs more than once.
        for _ in range(args.ppo_epochs):
            # Current-policy logps (with gradients)
            per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))

            # Importance ratio: ratio = exp(curr - old)
            ratio = torch.exp(per_token_logps - old_per_token_logps)

            # Clipping statistics (for logging):
            # find the tokens above the upper bound and below the lower bound
            over_upper = (ratio > (1.0 + args.clip_ratio)).float()
            under_lower = (ratio < (1.0 - args.clip_ratio)).float()

            # Overall clipped fraction (token level)
            clip_fraction = (over_upper + under_lower).mean().item()

            # Upper and lower fractions separately (helps show whether the
            # policy is getting more aggressive or more conservative)
            upper_fraction = over_upper.mean().item()
            lower_fraction = under_lower.mean().item()

            # --- CISPO core logic ---
            # 1. Clipped weight coefficient. The .detach() is essential: it turns
            #    the weight into a constant coefficient rather than part of the objective.
            clipped_ratio = torch.clamp(ratio, 1.0 - args.clip_ratio, 1.0 + args.clip_ratio).detach()

            # 2. Token-level KL penalty
            kl_div = ref_per_token_logps - per_token_logps
            per_token_kl = torch.exp(kl_div) - kl_div - 1

            # 3. Loss
            # Formula: L = -(clipped_weight * log_pi * A) + beta * KL
            # PyTorch autograd then yields the gradient weight * A * grad(log_pi).
            cispo_loss = - (clipped_ratio * per_token_logps * advantages.unsqueeze(1)) + args.beta * per_token_kl
            # --- CISPO core logic end ---

            # Mask and aggregate
            loss = ((cispo_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() / args.accumulation_steps

            loss.backward()

            if (step + 1) % args.accumulation_steps == 0:
                if args.grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                torch.cuda.empty_cache()
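
        # Weight illustration (clip_ratio = 0.2): ratio 0.5 -> weight 0.8,
        # ratio 1.0 -> 1.0, ratio 3.0 -> 1.2; the token still updates, but its
        # effective step size is capped at 1 ± clip_ratio.
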
        if step % args.log_interval == 0 or step == iters:
            # Logging
            policy_loss_val = loss.item() * args.accumulation_steps
            avg_reward_val = rewards.mean().item()
            avg_len_val = completion_mask.sum(dim=1).float().mean().item()
            current_lr = optimizer.param_groups[0]['lr']

            Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
                   f'CISPO Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
                   f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')

            if wandb and is_main_process():
                wandb.log({
                    "policy_loss": policy_loss_val,
                    "reward": avg_reward_val,
                    "avg_response_len": avg_len_val,
                    "advantages_mean": advantages.mean().item(),
                    "learning_rate": current_lr,
                    "ratio": ratio.mean().item(),
                    "cispo/clip_fraction": clip_fraction,          # total clipped fraction
                    "cispo/upper_clip_fraction": upper_fraction,   # clipped high (token probability rose sharply)
                    "cispo/lower_clip_fraction": lower_fraction,   # clipped low (token probability fell sharply)
                })

        # Model saving, unchanged from the GRPO trainer
        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
            moe_suffix = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
                          epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
            model.train()
            del state_dict

        del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps, old_per_token_logps
        del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
        torch.cuda.empty_cache()
        gc.collect()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind CISPO (Clipped Importance Sampling Policy Optimization)")
    # Original arguments, unchanged
    parser.add_argument("--save_dir", type=str, default="../out", help="model save directory")
    parser.add_argument('--save_weight', default='cispo', type=str, help="prefix for saved weights")
    parser.add_argument("--epochs", type=int, default=1, help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=2, help="batch size")
    parser.add_argument("--learning_rate", type=float, default=8e-8, help="initial learning rate")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="mixed-precision dtype")
    parser.add_argument("--num_workers", type=int, default=1, help="number of dataloader workers")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="gradient accumulation steps")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
    parser.add_argument("--log_interval", type=int, default=1, help="logging interval")
    parser.add_argument("--save_interval", type=int, default=10, help="model save interval")
    parser.add_argument('--hidden_size', default=512, type=int, help="hidden dimension")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="use MoE architecture (0=no, 1=yes)")
    parser.add_argument('--max_seq_len', default=66, type=int, help="maximum prompt length")
    parser.add_argument("--max_gen_len", type=int, default=1536, help="maximum generation length")
    parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF data path")
    parser.add_argument("--num_generations", type=int, default=8, help="samples generated per prompt")
    parser.add_argument("--beta", type=float, default=0.02, help="KL penalty coefficient")
    parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='reasoning-model type')
    parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="reward model path")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="auto-detect checkpoint and resume training")
    parser.add_argument("--use_wandb", action="store_true", help="enable wandb logging")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-CISPO", help="wandb project name")

    # === New CISPO arguments ===
    parser.add_argument("--clip_ratio", type=float, default=0.2, help="CISPO/PPO clipping coefficient")
    parser.add_argument("--ppo_epochs", type=int, default=1, help="updates per batch (inner loop)")

    args = parser.parse_args()

    # The remaining initialization matches the original GRPO trainer; only the
    # train_epoch call is replaced with cispo_train_epoch.
    # ========== 1. Initialize environment ==========
    local_rank = init_distributed_mode()
    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))

    os.makedirs(args.save_dir, exist_ok=True)
    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
                               max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
    ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None

    device_type = "cuda" if "cuda" in args.device else "cpu"
    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
    autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)

    wandb = None
    if args.use_wandb and is_main_process():
        import swanlab as wandb
        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
        resume = 'must' if wandb_id else None
        wandb_run_name = f"MiniMind-CISPO-Epoch-{args.epochs}"
        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume, mode="local")

    base_weight = "reason" if args.reasoning == 1 else "full_sft"
    model, tokenizer = init_model(lm_config, base_weight, device=args.device)
    ref_model, _ = init_model(lm_config, base_weight, device=args.device)
    ref_model = ref_model.eval().requires_grad_(False)

    reward_model = AutoModel.from_pretrained(args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True)
    reward_model = reward_model.to(args.device).eval().requires_grad_(False)
    reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)

    train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
    iters = len(loader_for_count)
    total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs * args.ppo_epochs  # note: total steps now include the inner loop
    scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
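    # Example: iters = 100, accumulation_steps = 1, epochs = 1, ppo_epochs = 2
    # gives T_max = (100 // 1) * 1 * 2 = 200, since scheduler.step() now runs
    # once per inner-loop update.
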
    start_epoch, start_step = 0, 0
    if ckp_data:
        model.load_state_dict(ckp_data['model'])
        optimizer.load_state_dict(ckp_data['optimizer'])
        scheduler.load_state_dict(ckp_data['scheduler'])
        start_epoch = ckp_data['epoch']
        start_step = ckp_data.get('step', 0)

    if dist.is_initialized():
        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
        model = DistributedDataParallel(model, device_ids=[local_rank])

    for epoch in range(start_epoch, args.epochs):
        train_sampler and train_sampler.set_epoch(epoch)
        if epoch == start_epoch and start_step > 0:
            batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
            loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
            cispo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, start_step, wandb)
        else:
            loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True, drop_last=False, shuffle=(train_sampler is None), num_workers=args.num_workers, sampler=train_sampler)
            cispo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)
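
A typical invocation, assuming the repository layout above (the script lives under trainer/ and all paths are the argparse defaults); ppo_epochs > 1 is needed for the ratio to move away from 1 so that clipping can actually engage:

python trainer/train_cispo.py --ppo_epochs 2 --clip_ratio 0.2 --num_generations 8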