[add] add CISPO (Clipped Importance Sampling Policy Optimization) algorithm

This the main addition reference GRPO algorithm: By changing the gradient of out-of-bounds tokens from "directly set to 0" to "bounded clipping", we ensure that high-value exploration tokens can continue to participate in parameter updates while maintaining training stability.
Actually, this change is not very well in experiment, because the ratio is almost nearly at 1, which means seldom out-of-bound.
This commit is contained in:
vanking 2026-02-02 22:23:27 +08:00 committed by GitHub
parent 7389f64dee
commit ce9ed24dcd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

343
train_cispo.py Normal file
View File

@ -0,0 +1,343 @@
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import re
import gc
import warnings
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model
warnings.filterwarnings('ignore')
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
"""整合所有奖励函数计算总奖励"""
def reasoning_model_reward(rewards):
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
format_rewards = []
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
if match_pattern or match_pattern2:
format_rewards.append(0.5)
else:
format_rewards.append(0.0)
rewards += torch.tensor(format_rewards, device=args.device)
def mark_num(text):
reward = 0
if text.count("<think>") == 1: reward += 0.25
if text.count("</think>") == 1: reward += 0.25
if text.count("<answer>") == 1: reward += 0.25
if text.count("</answer>") == 1: reward += 0.25
return reward
mark_rewards = [mark_num(response) for response in responses]
rewards += torch.tensor(mark_rewards, device=args.device)
return rewards
rewards = torch.zeros(len(responses), device=args.device)
if args.reasoning == 1:
rewards = reasoning_model_reward(rewards)
with torch.no_grad():
reward_model_scores = []
batch_size = len(prompts)
scale = 3.0
for i in range(batch_size):
for j in range(args.num_generations):
response_idx = i * args.num_generations + j
response = responses[response_idx]
prompt = prompts[i]
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
matches = re.findall(pattern, prompt, re.DOTALL)
messages = [{"role": role, "content": content.strip()} for role, content in matches]
tmp_chat = messages + [{"role": "assistant", "content": response}]
score = reward_model.get_score(reward_tokenizer, tmp_chat)
score = max(min(score, scale), -scale)
if args.reasoning == 1:
answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
if answer_match:
answer_content = answer_match.group(1).strip()
tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
answer_score = max(min(answer_score, scale), -scale)
score = score * 0.4 + answer_score * 0.6
reward_model_scores.append(score)
reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
rewards += reward_model_scores
return rewards
def get_per_token_logps(mdl, input_ids, n_keep):
"""计算每个token的log概率"""
# CISPO 需要多次计算 logps必须保留梯度
if not mdl.training:
with torch.no_grad():
logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
else:
logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
per_token_logps = []
for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
return torch.stack(per_token_logps)
def cispo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
for step, batch in enumerate(loader, start=start_step + 1):
prompts = batch['prompt']
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
padding_side="left", add_special_tokens=False).to(args.device)
if args.max_seq_len:
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
# ========== 1. 采样 (Sampling) ==========
with torch.no_grad():
model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
outputs = model_for_gen.generate(
**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id)
completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]
# ========== 2. 准备 Reference 和 Old Policy ==========
with torch.no_grad():
# 计算参考模型的 logps (用于 KL 惩罚)
ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))
# 计算当前(旧)策略的 logps (用于计算 Ratio在 Inner Loop 中保持不变)
old_per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))
# ========== 3. 计算奖励 (Reward & Advantage) ==========
completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)
grouped_rewards = rewards.view(-1, args.num_generations)
mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations)
std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations)
# 归一化优势
advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
# 此处可以做全局归一化
# advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# 准备 Mask
is_eos = completion_ids == tokenizer.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()
# ========== 4. CISPO Inner Loop (多步更新) ==========
# 循环多次Ratio 才会发生变化Clipping 才会生效
for _ in range(args.ppo_epochs):
# 获取当前模型的 logps (带有梯度)
per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))
# 计算 Ratio
# ratio = exp(curr - old)
ratio = torch.exp(per_token_logps - old_per_token_logps)
# 计算裁剪统计信息 (用于日志记录)
# 找到高于上限和低于下限的部分
over_upper = (ratio > (1.0 + args.clip_ratio)).float()
under_lower = (ratio < (1.0 - args.clip_ratio)).float()
# 计算总的被裁剪比例 (Token 级别)
clip_fraction = (over_upper + under_lower).mean().item()
# 分别查看上溢和下溢(有助于分析模型是变激进了还是变保守了)
upper_fraction = over_upper.mean().item()
lower_fraction = under_lower.mean().item()
# --- CISPO 核心逻辑 ---
# 1. 计算裁剪后的权重系数 (Weight Clipping)
# 关键:必须 .detach(),使其变为常数系数,而不是目标函数的一部分
clipped_ratio = torch.clamp(ratio, 1.0 - args.clip_ratio, 1.0 + args.clip_ratio).detach()
# 2. 计算 KL 惩罚 (Token-level)
kl_div = ref_per_token_logps - per_token_logps
per_token_kl = torch.exp(kl_div) - kl_div - 1
# 3. 计算 Loss
# 公式: L = - (Clipped_Weight * log_pi * A) + beta * KL
# 这里利用 pytorch 的自动求导nabla(log_pi) * A * Weight
cispo_loss = - (clipped_ratio * per_token_logps * advantages.unsqueeze(1)) + args.beta * per_token_kl
# --- CISPO 核心逻辑 End ---
# Mask 和 聚合
loss = ((cispo_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() / args.accumulation_steps
loss.backward()
if (step + 1) % args.accumulation_steps == 0:
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
torch.cuda.empty_cache()
if step % args.log_interval == 0 or step == iters:
# 记录日志
policy_loss_val = loss.item() * args.accumulation_steps
avg_reward_val = rewards.mean().item()
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
current_lr = optimizer.param_groups[0]['lr']
Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
f'CISPO Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
if wandb and is_main_process():
wandb.log({
"policy_loss": policy_loss_val,
"reward": avg_reward_val,
"avg_response_len": avg_len_val,
"advantages_mean": advantages.mean().item(),
"learning_rate": current_lr,
"ratio": ratio.mean().item(),
"cispo/clip_fraction": clip_fraction, # 总裁剪率
"cispo/upper_clip_fraction": upper_fraction, # 上裁剪率(由于当前动作概率大幅增加导致)
"cispo/lower_clip_fraction": lower_fraction, # 下裁剪率(由于当前动作概率大幅降低导致)
})
# 保存模型的代码保持不变
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
model.train()
del state_dict
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps, old_per_token_logps
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
torch.cuda.empty_cache()
gc.collect()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind CISPO (Clipped Importance Sampling Policy Optimization)")
# 原有参数保持不变
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='cispo', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
parser.add_argument("--learning_rate", type=float, default=8e-8, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
parser.add_argument("--num_generations", type=int, default=8, help="每个prompt生成的样本数")
parser.add_argument("--beta", type=float, default=0.02, help="KL惩罚系数")
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型')
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-CISPO", help="wandb项目名")
# === CISPO 新增参数 ===
parser.add_argument("--clip_ratio", type=float, default=0.2, help="CISPO/PPO 裁剪系数")
parser.add_argument("--ppo_epochs", type=int, default=1, help="每个Batch的更新次数 (Inner Loop)")
args = parser.parse_args()
# 后续初始化代码与原 GRPO 代码一致,只需将 train_epoch 调用替换为 cispo_train_epoch
# ========== 1. 初始化环境 ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-CISPO-Epoch-{args.epochs}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume, mode="local")
base_weight = "reason" if args.reasoning == 1 else "full_sft"
model, tokenizer = init_model(lm_config, base_weight, device=args.device)
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
ref_model = ref_model.eval().requires_grad_(False)
reward_model = AutoModel.from_pretrained(args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True)
reward_model = reward_model.to(args.device).eval().requires_grad_(False)
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
iters = len(loader_for_count)
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs * args.ppo_epochs # 注意这里total steps变了
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scheduler.load_state_dict(ckp_data['scheduler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
if epoch == start_epoch and start_step > 0:
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
cispo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, start_step, wandb)
else:
loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True, drop_last=False, shuffle=(train_sampler is None), num_workers=args.num_workers, sampler=train_sampler)
cispo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)