[add] add DAPO algorithm (Decoupled Clip and Dynamic sAmpling Policy Optimization)

I mainly add two features to the original GRPO algorithm, following this [paper](https://arxiv.org/pdf/2503.14476): Clip-Higher & Dynamic Sampling.
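As a quick illustration (not part of the committed file), here is a minimal sketch of the two pieces in isolation; the function names are made up for this example, and the 0.2 / 0.28 / 0.3 defaults mirror the ones used in `train_dapo.py` below:

```python
import torch

def dapo_surrogate(ratio, advantages, eps_low=0.2, eps_high=0.28):
    """Clip-Higher: clip the importance ratio into the asymmetric range [1 - eps_low, 1 + eps_high]."""
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - eps_low, 1.0 + eps_high) * advantages
    return -torch.min(surr1, surr2)  # per-token policy loss, before any KL penalty or masking

def dynamic_sampling_mask(rewards, num_generations, std_threshold=0.3):
    """Dynamic Sampling: drop groups whose rewards barely differ, since they carry no advantage signal."""
    grouped = rewards.view(-1, num_generations)             # [B, num_generations]
    valid_group = grouped.std(dim=1) > std_threshold        # [B]
    return valid_group.repeat_interleave(num_generations)   # [B * num_generations]
```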
vanking 2026-02-02 22:16:52 +08:00 committed by GitHub
parent 35fe1399c8
commit 7389f64dee

train_dapo.py (new file, +351 lines)

@@ -0,0 +1,351 @@
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import re
import gc
import warnings
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model
warnings.filterwarnings('ignore')
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
"""整合所有奖励函数计算总奖励"""
def reasoning_model_reward(rewards):
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
format_rewards = []
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
if match_pattern or match_pattern2:
format_rewards.append(0.5)
else:
format_rewards.append(0.0)
rewards += torch.tensor(format_rewards, device=args.device)
def mark_num(text):
reward = 0
if text.count("<think>") == 1: reward += 0.25
if text.count("</think>") == 1: reward += 0.25
if text.count("<answer>") == 1: reward += 0.25
if text.count("</answer>") == 1: reward += 0.25
return reward
mark_rewards = [mark_num(response) for response in responses]
rewards += torch.tensor(mark_rewards, device=args.device)
return rewards
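# Reward composition: an optional format reward (at most 0.5 + 4*0.25 = 1.5 when args.reasoning == 1)
# plus the reward-model score, clipped to [-3, 3] and, for reasoning models, blended 0.4/0.6 with the
# score of the extracted <answer> content.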
rewards = torch.zeros(len(responses), device=args.device)
if args.reasoning == 1:
rewards = reasoning_model_reward(rewards)
with torch.no_grad():
reward_model_scores = []
batch_size = len(prompts)
scale = 3.0
for i in range(batch_size):
for j in range(args.num_generations):
response_idx = i * args.num_generations + j
response = responses[response_idx]
prompt = prompts[i]
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
matches = re.findall(pattern, prompt, re.DOTALL)
messages = [{"role": role, "content": content.strip()} for role, content in matches]
tmp_chat = messages + [{"role": "assistant", "content": response}]
score = reward_model.get_score(reward_tokenizer, tmp_chat)
score = max(min(score, scale), -scale)
if args.reasoning == 1:
answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
if answer_match:
answer_content = answer_match.group(1).strip()
tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
answer_score = max(min(answer_score, scale), -scale)
score = score * 0.4 + answer_score * 0.6
reward_model_scores.append(score)
reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
rewards += reward_model_scores
return rewards
def get_per_token_logps(mdl, input_ids, n_keep):
"""计算每个token的log概率"""
# 确保在计算新策略logp时保留梯度计算旧策略时detach
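# The model returns logits for the last n_keep + 1 positions; dropping the final position with
# [:, :-1, :] leaves n_keep rows, where row t scores the generated token input_ids[:, -n_keep:][:, t].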
if not mdl.training:
with torch.no_grad():
logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
else:
logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
per_token_logps = []
for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
# Even in inference mode, be careful if ids_row requires grad (it usually does not); ids_row here are the generated token indices
per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
return torch.stack(per_token_logps)
def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
for step, batch in enumerate(loader, start=start_step + 1):
prompts = batch['prompt'] # list[str], length B
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
padding_side="left", add_special_tokens=False).to(args.device)
if args.max_seq_len:
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
# ========== Sampling ==========
with torch.no_grad():
model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
outputs = model_for_gen.generate(
**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id)
completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]
# Old-policy log probs, detached
with torch.no_grad():
old_per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))
ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))
completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)
# ========== Advantage Computation ==========
grouped_rewards = rewards.view(-1, args.num_generations) # [B, num_gen]
mean_r = grouped_rewards.mean(dim=1)
std_r = grouped_rewards.std(dim=1)
# DAPO Feature: Dynamic Sampling (filter out groups with near-zero reward variance)
# If std is 0, every sample in the group got the same reward, so no relative advantage can be formed; treat the group as invalid
valid_group_mask = std_r > 0.3 # A tiny threshold such as 1e-6 is too small for this project and filters nothing; in practice the mean std_r mostly falls in (0.3, 0.9)
# Expand valid_group_mask to [B * num_gen]
valid_sample_mask = valid_group_mask.repeat_interleave(args.num_generations)
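# Example with num_generations=8: if one prompt's 8 completions all receive reward 1.5, that group's
# std_r is 0, valid_group_mask is False, and its 8 samples are zeroed out of the loss via loss_mask below.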
# Standard GRPO advantage computation
mean_r = mean_r.repeat_interleave(args.num_generations)
std_r = std_r.repeat_interleave(args.num_generations)
advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
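# Group-relative advantage: A_i = clamp((r_i - mean(r_group)) / (std(r_group) + 1e-4), -10, 10),
# broadcast unchanged to every token of completion i.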
# Normalize using statistics from valid samples only (optional; keeps the all-zero invalid groups from dragging the mean down while staying a simple global normalization)
if valid_sample_mask.sum() > 0:
advantages = (advantages - advantages[valid_sample_mask].mean()) / (advantages[valid_sample_mask].std() + 1e-8)
else:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Build the padding mask from the first EOS position
is_eos = completion_ids == tokenizer.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()
# Combine with the Dynamic Sampling mask
loss_mask = completion_mask * valid_sample_mask.unsqueeze(1).int()
# ========== Policy Optimization ==========
# DAPO/PPO Inner Loop
for _ in range(args.ppo_epochs):
per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))
# Importance ratio
ratio = torch.exp(per_token_logps - old_per_token_logps) # [B*num_gen, R]
# KL divergence penalty (approximation)
kl_div = ref_per_token_logps - per_token_logps
per_token_kl = torch.exp(kl_div) - kl_div - 1
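# Standard GRPO-style estimator of KL(pi_theta || pi_ref): with d = log pi_ref - log pi_theta,
# exp(d) - d - 1 is always >= 0 and serves as the per-token KL penalty.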
# DAPO Feature: Decoupled Clip
# Clip-Higher: use separate epsilon_low and epsilon_high
surr1 = ratio * advantages.unsqueeze(1)
surr2 = torch.clamp(ratio, 1.0 - args.clip_ratio_low, 1.0 + args.clip_ratio_high) * advantages.unsqueeze(1)
# DAPO loss: -(min(surr1, surr2) - beta * KL)
per_token_loss = -torch.min(surr1, surr2) + args.beta * per_token_kl
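# With the defaults below (clip_ratio_low=0.2, clip_ratio_high=0.28) the trust region is asymmetric,
# [0.8, 1.28]: per the DAPO paper, raising only the upper bound gives low-probability tokens more room
# to increase, which helps prevent entropy collapse.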
# Aggregate loss: mean over valid tokens within each sample, then mean over samples
loss = ((per_token_loss * loss_mask).sum(dim=1) / (loss_mask.sum(dim=1) + 1e-8)).mean() / args.accumulation_steps
loss.backward()
if (step + 1) % args.accumulation_steps == 0:
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
torch.cuda.empty_cache()
if step % args.log_interval == 0 or step == iters:
policy_loss_val = loss.item() * args.accumulation_steps # undo the accumulation scaling
avg_reward_val = rewards[valid_sample_mask].mean().item() if valid_sample_mask.sum() > 0 else 0.0
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
current_lr = optimizer.param_groups[0]['lr']
valid_ratio = valid_group_mask.float().mean().item()
Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
f'Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
f'Valid Batch Ratio: {valid_ratio:.2f}, Avg Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
if wandb and is_main_process():
wandb.log({
"policy_loss": policy_loss_val,
"reward": avg_reward_val,
"valid_group_ratio": valid_ratio,
"avg_response_len": avg_len_val,
"advantages_mean": advantages[valid_sample_mask].mean().item() if valid_sample_mask.sum() > 0 else 0,
"learning_rate": current_lr,
"mean_std": std_r.mean().item()
})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
model.train()
del state_dict
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps, old_per_token_logps
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
torch.cuda.empty_cache()
gc.collect()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind DAPO (Decoupled Clip and Dynamic sAmpling Policy Optimization)")
parser.add_argument("--save_dir", type=str, default="../out", help="directory to save model weights")
parser.add_argument('--save_weight', default='dapo', type=str, help="prefix for saved weight files")
parser.add_argument("--epochs", type=int, default=1, help="number of training epochs")
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
parser.add_argument("--learning_rate", type=float, default=8e-8, help="initial learning rate")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
parser.add_argument("--dtype", type=str, default="bfloat16", help="mixed-precision dtype")
parser.add_argument("--num_workers", type=int, default=1, help="number of dataloader workers")
parser.add_argument("--accumulation_steps", type=int, default=1, help="gradient accumulation steps")
parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
parser.add_argument("--log_interval", type=int, default=1, help="logging interval (steps)")
parser.add_argument("--save_interval", type=int, default=10, help="checkpoint save interval (steps)")
parser.add_argument('--hidden_size', default=512, type=int, help="hidden size")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="use the MoE architecture (0=no, 1=yes)")
parser.add_argument('--max_seq_len', default=66, type=int, help="maximum prompt length")
parser.add_argument("--max_gen_len", type=int, default=1536, help="maximum generation length")
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="path to the RLAIF dataset")
parser.add_argument("--num_generations", type=int, default=8, help="number of completions sampled per prompt")
parser.add_argument("--beta", type=float, default=0.01, help="KL penalty coefficient")
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='reasoning model type (1=reasoning model with format rewards, 0=plain SFT model)')
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="path to the reward model")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="auto-detect checkpoint and resume training")
parser.add_argument("--use_wandb", action="store_true", help="whether to use wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-DAPO", help="wandb project name")
# DAPO-specific arguments
parser.add_argument("--clip_ratio_low", type=float, default=0.2, help="DAPO: lower clip epsilon (lower bound = 1 - 0.2)")
parser.add_argument("--clip_ratio_high", type=float, default=0.28, help="DAPO: upper clip epsilon (upper bound = 1 + 0.28)")
parser.add_argument("--ppo_epochs", type=int, default=1, help="DAPO: number of policy updates per batch")
args = parser.parse_args()
# ========== 1. Initialize environment and random seed ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Set up directories, model config, and check for checkpoints ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Configure wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-DAPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume, mode="local")
# ========== 5. Initialize models and data ==========
base_weight = "reason" if args.reasoning == 1 else "full_sft"
# Policy model
model, tokenizer = init_model(lm_config, base_weight, device=args.device)
# Reference model
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
ref_model = ref_model.eval().requires_grad_(False)
# Reward model
reward_model = AutoModel.from_pretrained(
args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
)
reward_model = reward_model.to(args.device).eval().requires_grad_(False)
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
# Data and optimizer
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
iters = len(loader_for_count)
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
# ========== 6. Resume state from checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scheduler.load_state_dict(ckp_data['scheduler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. Wrap model with DDP ==========
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
if epoch == start_epoch and start_step > 0:
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
grpo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, start_step, wandb)
else:
loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
drop_last=False, shuffle=(train_sampler is None),
num_workers=args.num_workers, sampler=train_sampler)
grpo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)