[mod] fix spo algorithm in RLAIF part

This commit is contained in:
Your Name 2026-01-30 11:03:35 +08:00
parent a9c56b20e9
commit 020bd44f3f
2 changed files with 217 additions and 81 deletions

View File

@ -195,6 +195,53 @@ class DPODataset(Dataset):
return loss_mask
# Add SPODataset
class SPODataset(Dataset):
def __init__(self, jsonl_path, tokenizer, max_length=1024):
super().__init__()
self.tokenizer = tokenizer
self.max_length = max_length
self.samples = self.load_data(jsonl_path)
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
def __len__(self):
return len(self.samples)
def load_data(self, path):
samples = []
with open(path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
data = json.loads(line.strip())
samples.append(data)
return samples
def _create_chat_prompt(self, conversations):
"""构建符合ChatML格式的对话"""
messages = []
answer = ''
for i, turn in enumerate(conversations):
role = 'user' if i % 2 == 0 else 'assistant'
messages.append({"role": role, "content": turn['content']})
answer = turn['content']
return self.tokenizer.apply_chat_template(
messages[:-1],
tokenize=False,
add_generation_prompt=True  # must be True here
), answer
def __getitem__(self, index):
sample = self.samples[index]
# Build the conversation prompt
prompt, answer = self._create_chat_prompt(sample['conversations'])
return {
'prompt': prompt,
'answer': answer,
'index': index  # key change: return the sample index
}
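# Usage sketch (hypothetical data, default DataLoader collate): each batch comes out
# as a dict where 'prompt' and 'answer' are lists of strings and 'index' is collated
# into a LongTensor of dataset indices, e.g.
#   batch = next(iter(DataLoader(spo_ds, batch_size=2)))
#   batch['prompt'] -> ['<chat-templated prompt 1>', '<chat-templated prompt 2>']
#   batch['index']  -> tensor([0, 1])
# The training loop uses 'index' to write updated value estimates back into the sampler.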
class RLAIFDataset(Dataset):
def __init__(self, jsonl_path, tokenizer, max_length=1024):
super().__init__()
@ -240,5 +287,6 @@ class RLAIFDataset(Dataset):
}
if __name__ == "__main__":
pass

View File

@ -18,14 +18,47 @@ from torch.utils.data import DataLoader, DistributedSampler
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from dataset.lm_dataset import SPODataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model
warnings.filterwarnings('ignore')
# --- 1. Custom priority sampler (with distributed synchronization) ---
class WeightedDistributedSampler(DistributedSampler):
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, epsilon=1e-5, device='cpu'):
super().__init__(dataset, num_replicas, rank, shuffle, seed)
# [change]: allow choosing which device holds the weights; defaults to CPU
self.weights = torch.ones(len(dataset), dtype=torch.float32).to(device)
self.epsilon = epsilon
self.device = device
def update_weights(self, indices, new_v_estimates):
"""更新权重"""
# 如果 self.weights 在 GPU这里就不再产生同步阻塞
self.weights[indices] = new_v_estimates.to(self.weights.device)
def sync_weights(self):
"""核心:在 Epoch 结束时同步所有卡的权重,防止采样漂移"""
if dist.is_initialized():
dist.all_reduce(self.weights, op=dist.ReduceOp.SUM)
self.weights /= dist.get_world_size()
def __iter__(self):
# Priority formula: sqrt(v * (1 - v)) + epsilon
priority = torch.sqrt(self.weights * (1 - self.weights) + 1e-8) + self.epsilon
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
# Global weighted sampling
indices = torch.multinomial(priority, self.total_size, replacement=True, generator=g).tolist()
# Assign to the current process (rank)
indices = indices[self.rank:self.total_size:self.num_replicas]
return iter(indices)
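# Intuition for the priority formula (illustrative values only): sqrt(v * (1 - v))
# peaks at v = 0.5 and vanishes at v = 0 or v = 1, so samples whose value estimate is
# most uncertain are drawn most often, while epsilon keeps every sample reachable:
#   v = 0.50 -> priority ~ 0.50 + epsilon
#   v = 0.90 -> priority ~ 0.30 + epsilon
#   v = 0.99 -> priority ~ 0.10 + epsilon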
# --- 2. SPO adaptive value tracker ---
class AutoAdaptiveValueTracker:
"""SPO自适应价值追踪器"""
def __init__(self, rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96):
self.rho_mode = rho_mode
self.rho_const = rho_const
@ -42,9 +75,7 @@ class AutoAdaptiveValueTracker:
return torch.full((batch_size,), baseline, dtype=torch.float32)
def compute_rho(self, cur_mean_logprob):
if self.rho_mode == 'constant':
return self.rho_const
if self.old_mean_logprob is None:
if self.rho_mode == 'constant' or self.old_mean_logprob is None:
return self.rho_const
kl = abs(self.old_mean_logprob - cur_mean_logprob)
rho = 2 ** (-kl / self.D_half)
@ -64,7 +95,7 @@ class AutoAdaptiveValueTracker:
self.alpha = rho * self.alpha + avg_normalized_reward
self.beta = rho * self.beta + (1 - avg_normalized_reward)
return rho
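# Worked example for the decay factor (hypothetical numbers): with rho_mode='kl' and
# D_half=0.06, rho = 2 ** (-kl / D_half), so kl = 0.06 halves the old statistics
# (rho = 0.5), kl = 0.12 gives rho = 0.25, and kl -> 0 gives rho -> 1; a large policy
# shift therefore decays alpha/beta faster so the baseline tracks the new policy sooner.
# The clip_lower/clip_upper bounds presumably clamp rho, though that code is outside this hunk.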
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
"""整合所有奖励函数计算总奖励"""
@ -128,84 +159,96 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
return rewards
def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, value_tracker, start_step=0, wandb=None):
for step, batch in enumerate(loader, start=start_step + 1):
prompts = batch['prompt'] # list[str], length B
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
padding_side="left", add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
if args.max_seq_len:
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
# --- 4. Core training loop ---
def spo_train_epoch(epoch, loader, iters, model, ref_model, reward_model, reward_tokenizer,
value_tracker, sampler, tokenizer, args, autocast_ctx, wandb=None):
model.train()
for step, batch in enumerate(loader, start=1):
prompts = batch['prompt']
indices = batch['index']
# Data preprocessing
prompt_inputs = tokenizer(
prompts,
return_tensors="pt",
padding=True,
padding_side="left",
add_special_tokens=False,
return_token_type_ids=False
).to(args.device)
with torch.no_grad():
# A DDP-wrapped model needs .module to access generate
if args.max_seq_len:
prompt_inputs = {k: v[:, -args.max_seq_len:] for k, v in prompt_inputs.items()}
# 1. Sampled generation
with torch.no_grad(), autocast_ctx:
model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
outputs = model_for_gen.generate(
**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
num_return_sequences=1, pad_token_id=tokenizer.pad_token_id) # [B, P+R]
completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):] # [B, R]
pad_token_id=tokenizer.pad_token_id) # use_cache = False
completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]
completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
# 2. Compute logprobs
def get_per_token_logps(mdl, input_ids, n_keep):
input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
per_token_logps = []
for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
return torch.stack(per_token_logps)
target_ids = input_ids[:, -n_keep:]
log_probs = logits.log_softmax(dim=-1)
return torch.gather(log_probs, 2, target_ids.unsqueeze(2)).squeeze(2)
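# Shape walk-through (assuming logits_to_keep makes the model return logits only for
# the last logits_to_keep positions): outputs is [B, P+R] and n_keep = R, so the model
# yields logits [B, R+1, V]; dropping the last position gives [B, R, V], target_ids is
# [B, R], and the gather over the vocab dim returns per-token log-probs of shape [B, R].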
with autocast_ctx:
per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))
with torch.no_grad():
ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))
per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1)) # [B, R]
with torch.no_grad():
ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1)) # [B, R]
# 3. Reward and advantage computation
rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer)
baselines = value_tracker.get_baselines(len(prompts)).to(args.device)
completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True) # list[str], length B
rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device) # [B]
advantages = (rewards - baselines).clamp(-5.0, 5.0)
baselines = value_tracker.get_baselines(len(prompts)).to(args.device) # [B]
# 4. Update sampling weights (EMA)
norm_rewards = ((rewards.detach() + 3.0) / 6.0).clamp(0, 1)
current_v_gpu = sampler.weights[indices].to(args.device) # pull the old weights from CPU onto the GPU
updated_v_gpu = 0.7 * current_v_gpu + 0.3 * norm_rewards # EMA computed entirely on the GPU
sampler.update_weights(indices, updated_v_gpu.cpu()) # move the result back to CPU for storage
scale = 3.0
# Un-normalize baselines to be in the same scale as raw rewards [-3, 3]
unnormalized_baselines = baselines * (2 * scale) - scale # [B]
advantages = rewards - unnormalized_baselines # [B]
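# Worked example (hypothetical numbers): a raw reward of 1.5 is normalized to
# (1.5 + 3) / 6 = 0.75 for the sampler EMA, so a sample with old weight 0.5 moves to
# 0.7 * 0.5 + 0.3 * 0.75 = 0.575; conversely a baseline of 0.75 is mapped back to
# 0.75 * 6 - 3 = 1.5 on the raw reward scale before being subtracted from the rewards.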
# 5. Compute the loss
is_eos = completion_ids == tokenizer.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()
# Use the baseline-provided advantage estimate directly and only clip it to prevent gradient explosion.
# No in-batch normalization is applied, since the baseline already supplies a stable cross-batch reference.
advantages = advantages.clamp(-5.0, 5.0)
kl_div = ref_per_token_logps - per_token_logps
per_token_kl = torch.exp(kl_div) - kl_div - 1
per_token_loss = -per_token_logps * advantages.unsqueeze(1) + args.beta * per_token_kl
loss = ((per_token_loss * completion_mask).sum(dim=1) / (completion_mask.sum(dim=1) + 1e-8)).mean() / args.accumulation_steps
is_eos = completion_ids == tokenizer.eos_token_id # [B, R]
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device) # [B]
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int() # [B, R]
kl_div = ref_per_token_logps - per_token_logps # [B, R]
per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B, R]
per_token_loss = -per_token_logps * advantages.unsqueeze(1) + args.beta * per_token_kl # [B, R]
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() / args.accumulation_steps # scalar
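# Note on the KL term: exp(kl_div) - kl_div - 1 is the non-negative k3-style estimator
# of KL(policy || ref) built from per-token log-prob differences; it is zero when the
# two policies agree and grows smoothly as they diverge (a gap of 0.1 gives about 0.005).
# completion_mask limits both the policy-gradient term and this penalty to tokens up to
# and including the first EOS.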
# 6. Backpropagation
loss.backward()
rho = value_tracker.update(rewards, per_token_logps.detach(), completion_mask.float())
response_masks = completion_mask.float() # [B, R]
rho = value_tracker.update(rewards, per_token_logps.detach(), response_masks)
if (step + 1) % args.accumulation_steps == 0:
if args.grad_clip > 0:
if step % args.accumulation_steps == 0:
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
torch.cuda.empty_cache()
# 7. Logging
if step % args.log_interval == 0 or step == iters:
policy_loss_val = loss.item()
policy_loss_val = loss.item() * args.accumulation_steps # restore the displayed magnitude
avg_reward_val = rewards.mean().item()
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
kl_val = ((per_token_kl * completion_mask).sum() / (completion_mask.sum() + 1e-8)).item()
avg_baseline_val = baselines.mean().item()
current_lr = optimizer.param_groups[0]['lr']
Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}, '
f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
Logger(f'Step: {step}/{iters}, Loss: {policy_loss_val:.4f}, Reward: {avg_reward_val:.4f}, '
f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}')
if wandb and is_main_process():
wandb.log({
@ -215,26 +258,30 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
"rho": float(rho),
"baseline": avg_baseline_val,
"advantages_mean": advantages.mean().item(),
"learning_rate": current_lr
"learning_rate": current_lr,
})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
# 8. Checkpoint saving
if (step % args.save_interval == 0 or step == iters) and is_main_process():
# ### <--- Change 5: make sure lm_config is in scope (usually obtained from args or the model)
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
lm_checkpoint(model.module.config if hasattr(model, 'module') else model.config,
weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir=args.save_dir, scheduler=scheduler)
model.train()
del state_dict
# Free memory
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
del completions, rewards, advantages, completion_mask, baselines, response_masks
del completions, rewards, advantages, completion_mask, baselines
torch.cuda.empty_cache()
gc.collect()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind SPO (Self-Play Optimization)")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
@ -286,33 +333,54 @@ if __name__ == "__main__":
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-SPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume, mode="local")
# ========== 5. Initialize models (policy, ref, reward), the value tracker, and data ==========
base_weight = "reason" if args.reasoning == 1 else "full_sft"
# Policy model
model, tokenizer = init_model(lm_config, base_weight, device=args.device)
# Reference model
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
ref_model = ref_model.eval().requires_grad_(False)
# Reward model
reward_model = AutoModel.from_pretrained(
args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
)
reward_model = reward_model.to(args.device).eval().requires_grad_(False)
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
# Value Tracker
value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
# --- Key change annotations ---
train_ds = SPODataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
# [Note 1]: num_replicas and rank must be passed explicitly to keep the distributed sampling logic correct
train_sampler = WeightedDistributedSampler(
train_ds,
num_replicas=dist.get_world_size() if dist.is_initialized() else 1,
rank=local_rank
)
# [Note 2]: the DataLoader must be bound to this train_sampler
loader = DataLoader(
train_ds,
batch_size=args.batch_size,
sampler=train_sampler,
num_workers=args.num_workers,
pin_memory=True
)
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
iters = len(loader_for_count)
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
# Use len(loader) directly; no extra loader_for_count is needed, which saves overhead
iters = len(loader)
# Make sure total_optimizer_steps accounts for gradient accumulation
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
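# Worked example (hypothetical sizes): with len(loader) = 1000 batches per epoch,
# accumulation_steps = 4 and epochs = 2, the optimizer steps (1000 // 4) * 2 = 500 times
# in total, so T_max = 500 and the cosine schedule anneals the learning rate from
# args.learning_rate down to args.learning_rate / 10 over those 500 optimizer steps.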
# ========== 6. Restore state from checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
@ -327,16 +395,36 @@ if __name__ == "__main__":
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
if epoch == start_epoch and start_step > 0: # first epoch with an existing checkpoint
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
# set_epoch must be called so that the priorities are reshuffled each epoch
train_sampler.set_epoch(epoch)
if epoch == start_epoch and start_step > 0:
# Resume-from-checkpoint logic
batch_sampler = SkipBatchSampler(train_sampler, args.batch_size, start_step + 1)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
spo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, value_tracker, start_step, wandb)
else: # 默认从头开始
# [Change 1]: filled in the many missing arguments; the old call omitted model, tokenizer, args, autocast_ctx, etc.
# Note: iters here should be the global step count, corrected to len(loader) + start_step
spo_train_epoch(
epoch, loader, len(loader) + start_step + 1,
model, ref_model, reward_model, reward_tokenizer,
value_tracker, train_sampler, tokenizer,
args, autocast_ctx, wandb
)
train_sampler.sync_weights()
else:
# Regular training
loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
drop_last=False, shuffle=(train_sampler is None),
num_workers=args.num_workers, sampler=train_sampler)
spo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, value_tracker, 0, wandb)
drop_last=False, num_workers=args.num_workers, sampler=train_sampler)
spo_train_epoch(
epoch, loader, len(loader),
model, ref_model, reward_model, reward_tokenizer,
value_tracker, train_sampler, tokenizer,
args, autocast_ctx, wandb
)
# sync_weights must be called at the end of every epoch to synchronize the sampling weights across ranks
train_sampler.sync_weights()