mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00
[mod] fix spo algorithm in RLAIF part
This commit is contained in:
parent
a9c56b20e9
commit
020bd44f3f
@ -195,6 +195,53 @@ class DPODataset(Dataset):
|
||||
return loss_mask
|
||||
|
||||
|
||||
# 添加SPOdataset
|
||||
class SPODataset(Dataset):
|
||||
def __init__(self, jsonl_path, tokenizer, max_length=1024):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.samples = self.load_data(jsonl_path)
|
||||
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def load_data(self, path):
|
||||
samples = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
data = json.loads(line.strip())
|
||||
samples.append(data)
|
||||
return samples
|
||||
|
||||
def _create_chat_prompt(self, conversations):
|
||||
"""构建符合ChatML格式的对话"""
|
||||
messages = []
|
||||
answer = ''
|
||||
for i, turn in enumerate(conversations):
|
||||
role = 'user' if i % 2 == 0 else 'assistant'
|
||||
messages.append({"role": role, "content": turn['content']})
|
||||
answer = turn['content']
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages[:-1],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True # 这里需要True
|
||||
), answer
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.samples[index]
|
||||
# 构建对话提示
|
||||
prompt, answer = self._create_chat_prompt(sample['conversations'])
|
||||
|
||||
return {
|
||||
'prompt': prompt,
|
||||
'answer': answer,
|
||||
'index': index # 关键修改:返回索引
|
||||
}
|
||||
|
||||
|
||||
class RLAIFDataset(Dataset):
|
||||
def __init__(self, jsonl_path, tokenizer, max_length=1024):
|
||||
super().__init__()
|
||||
@ -240,5 +287,6 @@ class RLAIFDataset(Dataset):
|
||||
}
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
|
||||
@ -18,14 +18,47 @@ from torch.utils.data import DataLoader, DistributedSampler
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from transformers import AutoModel
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from dataset.lm_dataset import RLAIFDataset
|
||||
from dataset.lm_dataset import SPODataset
|
||||
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
# --- 1. 自定义优先采样器 (带分布式同步) ---
|
||||
class WeightedDistributedSampler(DistributedSampler):
|
||||
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, epsilon=1e-5, device='cpu'):
|
||||
super().__init__(dataset, num_replicas, rank, shuffle, seed)
|
||||
# [修改点]:允许指定 weights 存放的设备,默认 CPU
|
||||
self.weights = torch.ones(len(dataset), dtype=torch.float32).to(device)
|
||||
self.epsilon = epsilon
|
||||
self.device = device
|
||||
|
||||
def update_weights(self, indices, new_v_estimates):
|
||||
"""更新权重"""
|
||||
# 如果 self.weights 在 GPU,这里就不再产生同步阻塞
|
||||
self.weights[indices] = new_v_estimates.to(self.weights.device)
|
||||
|
||||
def sync_weights(self):
|
||||
"""核心:在 Epoch 结束时同步所有卡的权重,防止采样漂移"""
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(self.weights, op=dist.ReduceOp.SUM)
|
||||
self.weights /= dist.get_world_size()
|
||||
|
||||
def __iter__(self):
|
||||
# 优先级公式:sqrt(v * (1-v)) + epsilon
|
||||
priority = torch.sqrt(self.weights * (1 - self.weights) + 1e-8) + self.epsilon
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed + self.epoch)
|
||||
|
||||
# 全局加权采样
|
||||
indices = torch.multinomial(priority, self.total_size, replacement=True, generator=g).tolist()
|
||||
# 分配到当前进程 (Rank)
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
return iter(indices)
|
||||
|
||||
|
||||
# --- 2. SPO 自适应价值追踪器 ---
|
||||
class AutoAdaptiveValueTracker:
|
||||
"""SPO自适应价值追踪器"""
|
||||
def __init__(self, rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96):
|
||||
self.rho_mode = rho_mode
|
||||
self.rho_const = rho_const
|
||||
@ -42,9 +75,7 @@ class AutoAdaptiveValueTracker:
|
||||
return torch.full((batch_size,), baseline, dtype=torch.float32)
|
||||
|
||||
def compute_rho(self, cur_mean_logprob):
|
||||
if self.rho_mode == 'constant':
|
||||
return self.rho_const
|
||||
if self.old_mean_logprob is None:
|
||||
if self.rho_mode == 'constant' or self.old_mean_logprob is None:
|
||||
return self.rho_const
|
||||
kl = abs(self.old_mean_logprob - cur_mean_logprob)
|
||||
rho = 2 ** (-kl / self.D_half)
|
||||
@ -64,7 +95,7 @@ class AutoAdaptiveValueTracker:
|
||||
self.alpha = rho * self.alpha + avg_normalized_reward
|
||||
self.beta = rho * self.beta + (1 - avg_normalized_reward)
|
||||
return rho
|
||||
|
||||
|
||||
|
||||
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
"""整合所有奖励函数计算总奖励"""
|
||||
@ -128,84 +159,96 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
return rewards
|
||||
|
||||
|
||||
def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, value_tracker, start_step=0, wandb=None):
|
||||
for step, batch in enumerate(loader, start=start_step + 1):
|
||||
prompts = batch['prompt'] # list[str], length B
|
||||
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
|
||||
padding_side="left", add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
|
||||
if args.max_seq_len:
|
||||
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
|
||||
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
|
||||
# --- 4. 核心训练循环 ---
|
||||
def spo_train_epoch(epoch, loader, iters, model, ref_model, reward_model, reward_tokenizer,
|
||||
value_tracker, sampler, tokenizer, args, autocast_ctx, wandb=None):
|
||||
model.train()
|
||||
for step, batch in enumerate(loader, start=1):
|
||||
prompts = batch['prompt']
|
||||
indices = batch['index']
|
||||
|
||||
# 数据预处理
|
||||
prompt_inputs = tokenizer(
|
||||
prompts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
padding_side="left",
|
||||
add_special_tokens=False,
|
||||
return_token_type_ids=False
|
||||
).to(args.device)
|
||||
|
||||
with torch.no_grad():
|
||||
# DDP 模型需要使用 .module 访问 generate 方法
|
||||
if args.max_seq_len:
|
||||
prompt_inputs = {k: v[:, -args.max_seq_len:] for k, v in prompt_inputs.items()}
|
||||
|
||||
# 1. 采样生成
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
outputs = model_for_gen.generate(
|
||||
**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
|
||||
num_return_sequences=1, pad_token_id=tokenizer.pad_token_id) # [B, P+R]
|
||||
|
||||
completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):] # [B, R]
|
||||
|
||||
pad_token_id=tokenizer.pad_token_id) # use_cache = False
|
||||
completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]
|
||||
completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
|
||||
|
||||
# 2. 计算 Logprobs
|
||||
def get_per_token_logps(mdl, input_ids, n_keep):
|
||||
input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
|
||||
logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
|
||||
per_token_logps = []
|
||||
for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
|
||||
ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
|
||||
per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
|
||||
return torch.stack(per_token_logps)
|
||||
target_ids = input_ids[:, -n_keep:]
|
||||
log_probs = logits.log_softmax(dim=-1)
|
||||
return torch.gather(log_probs, 2, target_ids.unsqueeze(2)).squeeze(2)
|
||||
|
||||
with autocast_ctx:
|
||||
per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))
|
||||
with torch.no_grad():
|
||||
ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))
|
||||
|
||||
per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1)) # [B, R]
|
||||
with torch.no_grad():
|
||||
ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1)) # [B, R]
|
||||
# 3. 奖励与优势计算
|
||||
rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer)
|
||||
baselines = value_tracker.get_baselines(len(prompts)).to(args.device)
|
||||
|
||||
completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True) # list[str], length B
|
||||
rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device) # [B]
|
||||
advantages = (rewards - baselines).clamp(-5.0, 5.0)
|
||||
|
||||
baselines = value_tracker.get_baselines(len(prompts)).to(args.device) # [B]
|
||||
# 4. 更新采样权重 (EMA)
|
||||
norm_rewards = ((rewards.detach() + 3.0) / 6.0).clamp(0, 1)
|
||||
current_v_gpu = sampler.weights[indices].to(args.device) # 把 CPU 上的旧权重拉到 GPU
|
||||
updated_v_gpu = 0.7 * current_v_gpu + 0.3 * norm_rewards # 全程 GPU 计算
|
||||
sampler.update_weights(indices, updated_v_gpu.cpu()) # 计算完结果传回 CPU 存储
|
||||
|
||||
scale = 3.0
|
||||
# Un-normalize baselines to be in the same scale as raw rewards [-3, 3]
|
||||
unnormalized_baselines = baselines * (2 * scale) - scale # [B]
|
||||
advantages = rewards - unnormalized_baselines # [B]
|
||||
# 5. 计算损失
|
||||
is_eos = completion_ids == tokenizer.eos_token_id
|
||||
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
|
||||
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
||||
completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()
|
||||
|
||||
# 直接使用 baseline 提供的优势估计,只做裁剪防止梯度爆炸。不再做 batch 内归一化,因为 baseline 已经提供了跨 batch 的稳定基线
|
||||
advantages = advantages.clamp(-5.0, 5.0)
|
||||
kl_div = ref_per_token_logps - per_token_logps
|
||||
per_token_kl = torch.exp(kl_div) - kl_div - 1
|
||||
|
||||
per_token_loss = -per_token_logps * advantages.unsqueeze(1) + args.beta * per_token_kl
|
||||
loss = ((per_token_loss * completion_mask).sum(dim=1) / (completion_mask.sum(dim=1) + 1e-8)).mean() / args.accumulation_steps
|
||||
|
||||
is_eos = completion_ids == tokenizer.eos_token_id # [B, R]
|
||||
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device) # [B]
|
||||
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
||||
completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int() # [B, R]
|
||||
|
||||
kl_div = ref_per_token_logps - per_token_logps # [B, R]
|
||||
per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B, R]
|
||||
per_token_loss = -per_token_logps * advantages.unsqueeze(1) + args.beta * per_token_kl # [B, R]
|
||||
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() / args.accumulation_steps # scalar
|
||||
# 6. 反向传播
|
||||
loss.backward()
|
||||
|
||||
rho = value_tracker.update(rewards, per_token_logps.detach(), completion_mask.float())
|
||||
|
||||
response_masks = completion_mask.float() # [B, R]
|
||||
rho = value_tracker.update(rewards, per_token_logps.detach(), response_masks)
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
if args.grad_clip > 0:
|
||||
if step % args.accumulation_steps == 0:
|
||||
if args.grad_clip > 0:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# 7. 日志打印
|
||||
if step % args.log_interval == 0 or step == iters:
|
||||
policy_loss_val = loss.item()
|
||||
policy_loss_val = loss.item() * args.accumulation_steps # 恢复显示量级
|
||||
avg_reward_val = rewards.mean().item()
|
||||
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
|
||||
kl_val = ((per_token_kl * completion_mask).sum() / (completion_mask.sum() + 1e-8)).item()
|
||||
avg_baseline_val = baselines.mean().item()
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
|
||||
Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
|
||||
f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
|
||||
f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}, '
|
||||
f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
|
||||
Logger(f'Step: {step}/{iters}, Loss: {policy_loss_val:.4f}, Reward: {avg_reward_val:.4f}, '
|
||||
f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}')
|
||||
|
||||
if wandb and is_main_process():
|
||||
wandb.log({
|
||||
@ -215,26 +258,30 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
|
||||
"rho": float(rho),
|
||||
"baseline": avg_baseline_val,
|
||||
"advantages_mean": advantages.mean().item(),
|
||||
"learning_rate": current_lr
|
||||
"learning_rate": current_lr,
|
||||
|
||||
})
|
||||
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
# 8. 模型保存逻辑
|
||||
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||
# ### <--- 修改点 5: 确保 lm_config 在作用域内 (通常从 args 或 model 获取)
|
||||
model.eval()
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
|
||||
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
|
||||
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
|
||||
lm_checkpoint(model.module.config if hasattr(model, 'module') else model.config,
|
||||
weight=args.save_weight, model=model, optimizer=optimizer,
|
||||
epoch=epoch, step=step, wandb=wandb, save_dir=args.save_dir, scheduler=scheduler)
|
||||
model.train()
|
||||
del state_dict
|
||||
|
||||
# 清理内存
|
||||
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
|
||||
del completions, rewards, advantages, completion_mask, baselines, response_masks
|
||||
del completions, rewards, advantages, completion_mask, baselines
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind SPO (Self-Play Optimization)")
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
@ -286,33 +333,54 @@ if __name__ == "__main__":
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-SPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume, mode="local")
|
||||
|
||||
# ========== 5. 初始化模型(Policy, Ref, Reward)和Value Tracker、数据 ==========
|
||||
# ========== 5. 初始化模型(Policy, Ref, Reward)和Value Tracker、数据 ==========
|
||||
base_weight = "reason" if args.reasoning == 1 else "full_sft"
|
||||
# Policy模型
|
||||
model, tokenizer = init_model(lm_config, base_weight, device=args.device)
|
||||
# Reference模型
|
||||
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
|
||||
ref_model = ref_model.eval().requires_grad_(False)
|
||||
|
||||
# Reward模型
|
||||
reward_model = AutoModel.from_pretrained(
|
||||
args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
|
||||
)
|
||||
reward_model = reward_model.to(args.device).eval().requires_grad_(False)
|
||||
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
|
||||
|
||||
# Value Tracker
|
||||
value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)
|
||||
|
||||
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
# --- 关键改动标注 ---
|
||||
train_ds = SPODataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
|
||||
# [标注 1]: 必须显式传入 num_replicas 和 rank,确保分布式采样逻辑正确
|
||||
train_sampler = WeightedDistributedSampler(
|
||||
train_ds,
|
||||
num_replicas=dist.get_world_size() if dist.is_initialized() else 1,
|
||||
rank=local_rank
|
||||
)
|
||||
|
||||
# [标注 2]: DataLoader 必须绑定这个 train_sampler
|
||||
loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
sampler=train_sampler,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True
|
||||
)
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
|
||||
iters = len(loader_for_count)
|
||||
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||||
# 直接使用 loader 的长度即可,无需额外创建一个 loader_for_count,节省开销
|
||||
iters = len(loader)
|
||||
|
||||
# 确保 total_optimizer_steps 考虑了梯度累积
|
||||
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
@ -327,16 +395,36 @@ if __name__ == "__main__":
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
# 必须调用 set_epoch 使得 priority 随 epoch 重新洗牌
|
||||
train_sampler.set_epoch(epoch)
|
||||
|
||||
if epoch == start_epoch and start_step > 0:
|
||||
# 续训逻辑
|
||||
batch_sampler = SkipBatchSampler(train_sampler, args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
spo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, value_tracker, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
|
||||
# [修改点 1]: 补全了大量缺失的参数。原代码漏掉了 model, tokenizer, args, autocast_ctx 等。
|
||||
# 注意:这里的 iters 应该是全局步数,修正为 len(loader) + start_step
|
||||
spo_train_epoch(
|
||||
epoch, loader, len(loader) + start_step + 1,
|
||||
model, ref_model, reward_model, reward_tokenizer,
|
||||
value_tracker, train_sampler, tokenizer,
|
||||
args, autocast_ctx, wandb
|
||||
)
|
||||
train_sampler.sync_weights()
|
||||
|
||||
else:
|
||||
# 常规训练
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
|
||||
drop_last=False, shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers, sampler=train_sampler)
|
||||
spo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, value_tracker, 0, wandb)
|
||||
drop_last=False, num_workers=args.num_workers, sampler=train_sampler)
|
||||
|
||||
spo_train_epoch(
|
||||
epoch, loader, len(loader),
|
||||
model, ref_model, reward_model, reward_tokenizer,
|
||||
value_tracker, train_sampler, tokenizer,
|
||||
args, autocast_ctx, wandb
|
||||
)
|
||||
# sync_weights 必须在每个 epoch 结束时调用,以同步多卡的采样权重
|
||||
train_sampler.sync_weights()
|
||||
Loading…
Reference in New Issue
Block a user