Files
2026-06-01 17:50:39 +08:00

493 lines
31 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import datasets # noqa: F401 # Windows pyarrow/torch DLL conflict workaround (issue #771)
import re
import gc
import json
import math
import random
import signal
import argparse
import warnings
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoTokenizer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import AgentRLDataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model, LMForRewardModel
from trainer.rollout_engine import create_rollout_engine, compute_per_token_logps
warnings.filterwarnings('ignore')
# ================================ 工具与 Reward = Start ================================
def rep_penalty(text, n=3, cap=0.5):
    """Return an n-gram repetition penalty for *text*, clipped at *cap*.

    The text is lower-cased and split into word and punctuation tokens. The
    penalty is the share of duplicated n-grams scaled by ``2 * cap`` and then
    limited to ``cap``. Texts too short to form a single n-gram score 0.0.
    """
    tokens = re.findall(r"\w+|[^\w\s]", text.lower())
    window_count = len(tokens) - n + 1
    ngrams = [tuple(tokens[start:start + n]) for start in range(window_count)]
    if not ngrams:
        return 0.0
    duplicates = len(ngrams) - len(set(ngrams))
    return min(cap, duplicates * cap * 2 / len(ngrams))
# ======== 工具定义 ========
def _function_tool(name, description, properties, required):
    """Build one function-calling tool schema entry."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {"type": "object", "properties": properties, "required": required},
        },
    }


# Tool schemas for the six mock tools implemented further down in this file.
TOOLS = [
    _function_tool(
        "calculate_math", "计算数学表达式",
        {"expression": {"type": "string"}},
        ["expression"],
    ),
    _function_tool(
        "unit_converter", "单位换算",
        {"value": {"type": "number"}, "from_unit": {"type": "string"}, "to_unit": {"type": "string"}},
        ["value", "from_unit", "to_unit"],
    ),
    _function_tool(
        "get_current_weather", "获取天气",
        {"location": {"type": "string"}},
        ["location"],
    ),
    _function_tool(
        "get_current_time", "获取时间",
        {"timezone": {"type": "string", "default": "Asia/Shanghai"}},
        [],
    ),
    _function_tool(
        "get_exchange_rate", "查询汇率",
        {"from_currency": {"type": "string"}, "to_currency": {"type": "string"}},
        ["from_currency", "to_currency"],
    ),
    _function_tool(
        "translate_text", "翻译文本",
        {"text": {"type": "string"}, "target_language": {"type": "string"}},
        ["text", "target_language"],
    ),
]
# ======== 模拟数据 ========
# Canned lookup tables backing the mock tools.
# NOTE(review): several weather conditions are empty strings; these may be
# characters lost when the file was copied — confirm against the upstream file.
WEATHER_DATA = {
    "北京": ("28°C", ""),
    "上海": ("15°C", "多云"),
    "广州": ("32°C", "闷热"),
    "深圳": ("30°C", ""),
    "杭州": ("22°C", ""),
    "成都": ("18°C", "小雨"),
    "武汉": ("25°C", "多云"),
    "南京": ("20°C", ""),
    "西安": ("16°C", "大风"),
    "重庆": ("26°C", ""),
    "Tokyo": ("12°C", ""),
    "New York": ("8°C", "多云"),
    "London": ("5°C", "小雨"),
    "Paris": ("10°C", ""),
    "Sydney": ("25°C", "晴朗"),
}

# Fixed timestamps per timezone name.
TIME_DATA = {
    "Asia/Shanghai": "2025-03-07 14:30:00",
    "America/New_York": "2025-03-07 01:30:00",
    "Europe/London": "2025-03-07 06:30:00",
    "Asia/Tokyo": "2025-03-07 15:30:00",
    "Europe/Paris": "2025-03-07 07:30:00",
    "Australia/Sydney": "2025-03-07 17:30:00",
}

# Rates keyed by (from_currency, to_currency).
EXCHANGE_DATA = {
    ("USD", "CNY"): 7.21,
    ("EUR", "CNY"): 7.85,
    ("GBP", "CNY"): 9.12,
    ("JPY", "CNY"): 0.048,
    ("USD", "EUR"): 0.92,
    ("USD", "GBP"): 0.79,
    ("CNY", "JPY"): 20.83,
    ("AUD", "CNY"): 4.72,
}

# Translations keyed by (source text, target language).
TRANSLATE_DATA = {
    ("你好世界", "english"): "Hello World",
    ("Good morning", "chinese"): "早上好",
    ("今天天气真好", "english"): "The weather is nice today",
    ("I love programming", "chinese"): "我喜欢编程",
    ("机器学习很有趣", "english"): "Machine learning is interesting",
    ("Happy birthday", "chinese"): "生日快乐",
}

# Multiplicative factors keyed by "<from_unit>_<to_unit>".
# NOTE(review): the temperature entries are pure factors with no offset, so
# the mock converter does not produce real Celsius/Fahrenheit values.
UNIT_DATA = {
    "km_miles": 0.621371,
    "miles_km": 1.60934,
    "kg_pounds": 2.20462,
    "pounds_kg": 0.453592,
    "meters_feet": 3.28084,
    "feet_meters": 0.3048,
    "celsius_fahrenheit": 1.8,
    "fahrenheit_celsius": 0.5556,
}
# ======== 模拟执行 ========
# (look-alike character, Python operator) pairs applied in order before eval.
# Non-ASCII characters are written as escapes so they survive copy/paste and
# tools that strip "ambiguous" Unicode characters.
_MATH_SYMBOL_REPLACEMENTS = (
    ("^", "**"),
    ("\u00d7", "*"),   # MULTIPLICATION SIGN
    ("\u00f7", "/"),   # DIVISION SIGN
    ("\u2212", "-"),   # MINUS SIGN
    ("\uff08", "("),   # FULLWIDTH LEFT PARENTHESIS
    ("\uff09", ")"),   # FULLWIDTH RIGHT PARENTHESIS
)


def _normalize_math_expression(expression):
    """Return *expression* as a string with look-alike math symbols mapped to Python operators."""
    normalized = str(expression)
    for symbol, replacement in _MATH_SYMBOL_REPLACEMENTS:
        normalized = normalized.replace(symbol, replacement)
    return normalized


def _mock_calculate_math(args):
    """Evaluate ``args["expression"]`` (default "0") and return its value as a string."""
    expression = _normalize_math_expression(args.get("expression", "0"))
    # SECURITY: eval on model-generated text. Emptying __builtins__ is not a
    # real sandbox; only run this against trusted rollouts.
    return {"result": str(eval(expression, {"__builtins__": {}, "math": math}))}


def _mock_unit_converter(args):
    """Multiply ``value`` by the factor for "<from_unit>_<to_unit>"; unknown pairs use factor 1."""
    value = float(args.get("value", 0))
    pair = f"{args.get('from_unit', '').lower()}_{args.get('to_unit', '').lower()}"
    return {"result": round(value * UNIT_DATA.get(pair, 1), 4)}


def _mock_get_current_weather(args):
    """Return canned weather for ``location``, with a default for unknown places."""
    location = args.get("location")
    # NOTE(review): the empty default condition may be a character lost in
    # copying — confirm against the upstream file.
    temperature, condition = WEATHER_DATA.get(location, ("22°C", ""))
    return {"city": location, "temperature": temperature, "humidity": "65%", "condition": condition}


def _mock_get_current_time(args):
    """Return the canned timestamp for ``timezone`` (default "Asia/Shanghai")."""
    timezone = args.get("timezone", "Asia/Shanghai")
    return {"datetime": TIME_DATA.get(timezone, "2025-03-07 14:30:00"), "timezone": timezone}


def _mock_get_exchange_rate(args):
    """Return the canned rate for the currency pair; unknown pairs use 1.0."""
    source, target = args.get("from_currency"), args.get("to_currency")
    return {"from": source, "to": target, "rate": EXCHANGE_DATA.get((source, target), 1.0)}


def _mock_translate_text(args):
    """Return the canned translation, falling back to the input text."""
    key = (args.get("text"), args.get("target_language"))
    return {"translated_text": TRANSLATE_DATA.get(key, args.get("text", ""))}


# Tool name -> callable taking the parsed arguments dict and returning a result dict.
MOCK_RESULTS = {
    "calculate_math": _mock_calculate_math,
    "unit_converter": _mock_unit_converter,
    "get_current_weather": _mock_get_current_weather,
    "get_current_time": _mock_get_current_time,
    "get_exchange_rate": _mock_get_exchange_rate,
    "translate_text": _mock_translate_text,
}
# ======== 参数校验 ========
def _require_all(*keys):
    """Return a checker that is True only when every key has a truthy value."""
    def checker(a):
        return all(bool(a.get(key)) for key in keys)
    return checker


def _check_unit_converter(a):
    """Check unit_converter arguments; ``value`` may be 0 but must be present.

    Returns the raw result of the ``and`` chain (not coerced to bool).
    """
    return a.get("value") is not None and a.get("from_unit") and a.get("to_unit")


def _accept_any(a):
    """Accept any arguments (get_current_time has no required fields)."""
    return True


# Tool name -> argument validator used when scoring tool calls.
CHECK_ARGS = {
    "calculate_math": _require_all("expression"),
    "unit_converter": _check_unit_converter,
    "get_current_weather": _require_all("location"),
    "get_current_time": _accept_any,
    "get_exchange_rate": _require_all("from_currency", "to_currency"),
    "translate_text": _require_all("text", "target_language"),
}
# ======== 工具调用解析与执行 ========
def parse_tool_calls(text):
    """Extract tool calls from ``<tool_call>...</tool_call>`` spans in *text*.

    Each span body is parsed as JSON. Spans that are not valid JSON, or whose
    JSON value is not an object, are skipped, so every returned item is a dict
    and callers can safely use ``.get`` on it.

    Returns:
        list[dict]: parsed tool-call objects in order of appearance.
    """
    calls = []
    for payload in re.findall(r'<tool_call>(.*?)</tool_call>', text, re.DOTALL):
        try:
            parsed = json.loads(payload.strip())
        except (ValueError, RecursionError):
            # json.JSONDecodeError subclasses ValueError; RecursionError covers
            # pathologically nested payloads.
            continue
        if isinstance(parsed, dict):
            calls.append(parsed)
    return calls
def execute_tool(name, args):
    """Run the mock tool *name* with *args*, returning its result or None.

    None is returned when the tool is unknown, when *name* is not a string,
    or when the tool raises (including the timeout below).

    Where ``signal.SIGALRM`` is available and this is the main thread, the
    call is guarded by a 1-second alarm and the previous SIGALRM handler is
    restored afterwards. Elsewhere (e.g. Windows or a worker thread) the tool
    runs without a timeout instead of failing every call.
    """
    import threading

    fn = MOCK_RESULTS.get(name) if isinstance(name, str) else None
    if not fn:
        return None

    can_alarm = hasattr(signal, "SIGALRM") and threading.current_thread() is threading.main_thread()
    if not can_alarm:
        # No usable alarm on this platform/thread: best-effort, unguarded call.
        try:
            return fn(args)
        except Exception:
            return None

    def _on_timeout(signum, frame):
        raise TimeoutError()

    previous_handler = signal.signal(signal.SIGALRM, _on_timeout)
    try:
        signal.alarm(1)
        return fn(args)
    except Exception:
        # Malformed arguments, evaluation errors and timeouts all count as a failed call.
        return None
    finally:
        signal.alarm(0)
        # signal.signal() returns None when the old handler was not installed
        # from Python; it cannot be re-installed in that case.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
# ======== 多轮 Rollout ========
def rollout_single(rollout_engine, tokenizer, messages, tools, max_turns=3, max_new_tokens=256, thinking_ratio=0.5, device="cuda"):
    """Run one multi-turn tool-use rollout and collect its token-level trace.

    Each turn renders ``messages`` with the chat template, generates one
    completion, and, if the completion contains tool calls, executes them and
    appends the assistant and tool messages before the next turn. The loop
    stops at the first turn without tool calls or after ``max_turns`` turns.

    ``messages`` is mutated in place (assistant/tool messages are appended).

    Returns:
        tuple: ``(final_output, final_context, prompt_ids, response_ids,
        response_mask, response_old_logps, all_outputs, unfinished)`` where
        ``response_mask`` is 1 for generated tokens and 0 for tool-observation
        tokens, ``response_old_logps`` holds the rollout log-probs (0.0 for
        observation tokens), and ``unfinished`` is True when the last allowed
        turn still issued tool calls.
    """
    all_outputs = []
    prompt_ids = None
    response_ids = []
    response_mask = []
    response_old_logps = []
    final_context = ""
    unfinished = False
    # Thinking mode is sampled once per rollout and reused for every turn.
    open_thinking = random.random() < thinking_ratio
    for turn in range(max_turns):
        context = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools, open_thinking=open_thinking)
        inputs = tokenizer(context, return_tensors="pt", add_special_tokens=False).to(device)
        context_ids = inputs["input_ids"][0].tolist()
        # Only the first turn's rendered context is kept as the prompt.
        if prompt_ids is None:
            prompt_ids = context_ids
        rollout_result = rollout_engine.rollout(
            prompt_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            num_generations=1,
            max_new_tokens=max_new_tokens,
            temperature=0.8,
        )
        new_ids = rollout_result.completion_ids[0].tolist()
        new_logps = rollout_result.per_token_logps[0].tolist()
        # A length mismatch is only logged; zip() below truncates to the shorter list.
        if len(new_ids) != len(new_logps): Logger(f"rollout token/logprob length mismatch: {len(new_ids)} vs {len(new_logps)}")
        # Drop pad and eos tokens (and their log-probs) from the generated span.
        pairs = [(t, lp) for t, lp in zip(new_ids, new_logps) if t != tokenizer.pad_token_id and t != tokenizer.eos_token_id]
        new_ids = [t for t, _ in pairs]
        new_logps = [lp for _, lp in pairs]
        new_text = rollout_result.completions[0]
        all_outputs.append(new_text)
        response_ids.extend(new_ids)
        response_mask.extend([1] * len(new_ids))
        response_old_logps.extend(new_logps)
        final_context = context + new_text
        calls = parse_tool_calls(new_text)
        if not calls:
            break
        # Tool calls on the last allowed turn leave the rollout unfinished.
        unfinished = turn == max_turns - 1
        messages.append({"role": "assistant", "content": new_text})
        for call in calls:
            name, raw = call.get("name", ""), call.get("arguments", {})
            # Arguments may arrive as a JSON-encoded string; undecodable strings become {}.
            if isinstance(raw, str):
                try: raw = json.loads(raw)
                except: raw = {}
            result = execute_tool(name, raw)
            # Any falsy result (unknown tool, failed or timed-out call) is reported as "tool not found".
            result_str = (json.dumps(result, ensure_ascii=False) if result else '{"error": "tool not found"}')[:2048]  # truncate so astronomically large numbers cannot blow up the tokenizer input
            messages.append({"role": "tool", "content": result_str})
        # Re-render the conversation including the tool results; no generation
        # prompt is added when no further turn will follow.
        observe_context = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=not unfinished, tools=tools, open_thinking=open_thinking)
        observe_ids = tokenizer(observe_context, return_tensors="pt", add_special_tokens=False)["input_ids"][0].tolist()
        # NOTE(review): this slice assumes the first `current_len` tokens of the
        # re-rendered context line up with prompt + response collected so far;
        # pad/eos tokens were removed from the response above — confirm the
        # alignment holds for this tokenizer/template.
        current_len = len(prompt_ids) + len(response_ids)
        obs_delta = observe_ids[current_len:]
        # Observation tokens are part of the sequence but excluded from the loss (mask 0).
        response_ids.extend(obs_delta)
        response_mask.extend([0] * len(obs_delta))
        response_old_logps.extend([0.0] * len(obs_delta))
        final_context = observe_context
    final_output = all_outputs[-1] if all_outputs else ""
    prompt_ids = prompt_ids or []
    return final_output, final_context, prompt_ids, response_ids, response_mask, response_old_logps, list(all_outputs), unfinished
def rollout_batch(rollout_engine, tokenizer, messages_batch, tools_batch, num_gen, max_turns=3, max_new_tokens=256, thinking_ratio=0.5, device="cuda"):
    """Run ``num_gen`` independent rollouts for every (messages, tools) pair.

    Each rollout receives its own shallow copy of the message dicts, because
    ``rollout_single`` appends to the list it is given.

    Returns:
        tuple of 8 lists, each with one entry per rollout in prompt-major
        order: completions, contexts, prompt ids, response ids, response
        masks, old log-probs, per-turn outputs, unfinished flags.
    """
    # One accumulator per field returned by rollout_single, in the same order.
    columns = tuple([] for _ in range(8))
    for messages, tools in zip(messages_batch, tools_batch):
        for _ in range(num_gen):
            fresh_messages = [dict(message) for message in messages]
            outcome = rollout_single(
                rollout_engine, tokenizer, fresh_messages, tools,
                max_turns, max_new_tokens, thinking_ratio, device,
            )
            for column, value in zip(columns, outcome):
                column.append(value)
    return columns
# ======== Reward 计算 ========
def validate_gt_in_text(text, gt_list):
    """Return the subset of *gt_list* found in *text*.

    An item counts as found when its stripped string form occurs in the text
    (case-insensitive), or when it is a plain number (commas ignored) that is
    within 1e-6 of some standalone number in the text.

    Returns:
        set: the original ``gt_list`` items that matched.
    """
    haystack = str(text)
    haystack_lower = haystack.lower()
    # Standalone numbers only: not glued to a word character or a dot.
    number_tokens = re.findall(r'(?<![\w.])[-+]?\d+(?:\.\d+)?(?![\w.])', haystack.replace(',', ''))
    found_numbers = [float(token) for token in number_tokens]

    matched = set()
    for item in gt_list:
        needle = str(item).strip()
        if needle and needle.lower() in haystack_lower:
            matched.add(item)
            continue
        candidate = needle.replace(',', '')
        if not re.fullmatch(r'[-+]?\d+(?:\.\d+)?', candidate):
            continue
        target = float(candidate)
        if any(abs(target - number) < 1e-6 for number in found_numbers):
            matched.add(item)
    return matched
def calculate_rewards(prompts, completions, gt_batch, tools_batch, num_gen, reward_model=None, device="cuda", turn_outputs_batch=None, unfinished_batch=None):
    """Score every rollout and return a 1-D tensor of rewards clipped to [-3, 3].

    ``completions`` holds ``num_gen`` consecutive rollouts per prompt, so
    rollout ``idx`` belongs to sample ``idx // num_gen``.

    Every rollout is first penalised 0.5 per unbalanced ``<tool_call>`` /
    ``</tool_call>`` tag in each turn. Then:

    * No tool call parsed: length score, thinking-format scores when
      ``</think>`` is present, the optional reward-model score, and a
      repetition penalty on the answer.
    * At least one tool call: a tool-alignment score (valid calls compared
      with the number of ground-truth items, plus a penalty per invalid
      call), a ground-truth coverage score on the final text, an
      "unfinished" penalty, and a repetition penalty.

    Malformed tool calls (non-dict call, non-string name, arguments that are
    not a dict after decoding) are counted as invalid calls instead of
    raising.

    Args:
        prompts: rendered prompt strings, one per sample.
        completions: final-turn text of each rollout.
        gt_batch: ground-truth item lists, one per sample (None means none).
        tools_batch: tool schema lists, one per sample.
        num_gen: rollouts per sample.
        reward_model: optional object exposing ``get_score(messages, answer)``.
        device: device of the returned tensor.
        turn_outputs_batch: optional per-rollout list of every turn's text.
        unfinished_batch: optional per-rollout flag for an exhausted turn budget.

    Returns:
        torch.Tensor of shape ``(len(completions),)``.
    """
    rewards = torch.zeros(len(completions), device=device)
    for idx, response in enumerate(completions):
        reward = 0.0
        sample_idx = idx // num_gen
        tools = tools_batch[sample_idx]
        turn_outputs = turn_outputs_batch[idx] if turn_outputs_batch is not None else [response]
        unfinished = unfinished_batch[idx] if unfinished_batch is not None else False
        # Keep only the text after the first </think> of each turn (whole turn if absent).
        turn_answers = [turn.split('</think>', 1)[-1].strip() for turn in turn_outputs]
        answer = turn_answers[-1] if turn_answers else response.strip()
        valid_names = {t['function']['name'] for t in tools} if tools else set()
        tool_calls = []
        for turn_answer in turn_answers:
            tool_calls.extend(parse_tool_calls(turn_answer))  # parse tool calls
        # Tag penalty: unbalanced <tool_call> tags.
        reward -= 0.5 * sum(abs(turn.count('<tool_call>') - turn.count('</tool_call>')) for turn in turn_answers)

        if not tool_calls:
            # -------- no tool call: format scores + reward-model score --------
            reward += 0.5 if 5 <= len(response.strip()) <= 800 else -0.5  # length score
            if '</think>' in response:
                think, answer = response.split('</think>', 1)
                reward += 1.0 if 20 <= len(think.strip()) <= 300 else -0.5  # thinking length score
                reward += 0.25 if response.count('</think>') == 1 else -0.25  # single closing tag score
            answer = answer.strip()
            if reward_model is not None:
                prompt = prompts[sample_idx]
                # Rebuild the chat messages from the rendered prompt for the reward model.
                pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
                matches = re.findall(pattern, prompt, re.DOTALL)
                messages = [{"role": role, "content": content.strip()} for role, content in matches]
                reward += reward_model.get_score(messages, answer)  # reward-model score
            reward -= rep_penalty(answer)
        else:
            # -------- tool calls present: execution-result reward --------
            gt = gt_batch[sample_idx] or []
            valid_call_count = 0
            for tool_call in tool_calls:
                if not isinstance(tool_call, dict):
                    continue  # still counted as an invalid call in tool_gap below
                name, raw = tool_call.get("name", ""), tool_call.get("arguments", {})
                if isinstance(raw, str):
                    try:
                        raw = json.loads(raw)
                    except (ValueError, RecursionError):
                        raw = {}
                # The validators call .get() on the arguments and the name is
                # used as a dict/set key, so anything else is an invalid call.
                if not isinstance(name, str) or not isinstance(raw, dict):
                    continue
                check = CHECK_ARGS.get(name)
                valid_call_count += int(bool(name in valid_names and check and check(raw)))
            # Gap = |valid calls - expected items| + number of invalid calls.
            tool_gap = abs(valid_call_count - len(gt)) + max(0, len(tool_calls) - valid_call_count)
            reward += 0.5 if tool_gap == 0 else -0.5 * tool_gap  # tool alignment score
            # Only text after the last </tool_call> of the final turn counts as the answer.
            final_text = "" if unfinished else answer.split('</tool_call>')[-1]
            verified = validate_gt_in_text(final_text, gt) if gt else set()
            if gt:
                reward += 2.5 * len(verified) / len(gt)  # ground-truth coverage score
            if unfinished:
                reward -= 0.5  # unfinished penalty
            reward -= rep_penalty(final_text if final_text else answer)
        rewards[idx] = max(min(reward, 3.0), -3.0)  # clip total score
    return rewards
# ================================ 工具与 Reward = End ================================
def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model=None, start_step=0, wandb=None, use_sglang=False):
    """Run one epoch of agent RL training.

    For every batch: generate multi-turn rollouts, pack prompt + response
    tokens into padded tensors, score the rollouts, and take a policy-gradient
    step (CISPO or clipped-ratio loss) with a KL penalty against ``ref_model``.

    Relies on module-level names set up in the ``__main__`` block: ``model``,
    ``tokenizer``, ``optimizer``, ``scheduler``, ``args``, ``lm_config`` and
    ``autocast_ctx``.

    Args:
        epoch: zero-based epoch index (used for logging and checkpoints).
        loader: iterable of batches with ``messages``, ``tools`` and ``gt`` keys.
        iters: step number treated as the last step of the epoch.
        rollout_engine: object providing ``rollout`` and ``update_policy``.
        ref_model: frozen reference model for the KL term.
        reward_model: optional reward model forwarded to ``calculate_rewards``.
        start_step: number of steps already done (step numbering starts after it).
        wandb: optional logger exposing ``log``.
        use_sglang: accepted but not read anywhere in this function.
    """
    last_step = start_step
    for step, batch in enumerate(loader, start=start_step + 1):
        messages_batch = batch['messages']
        tools_batch = batch['tools']
        gt_batch = batch['gt']
        last_step = step
        with torch.no_grad():
            completions, contexts, prompt_ids_batch, response_ids_batch, response_masks_batch, response_old_logps_batch, turn_outputs_batch, unfinished_batch = rollout_batch(rollout_engine, tokenizer, messages_batch, tools_batch, args.num_generations, max_turns=3, max_new_tokens=args.max_gen_len, thinking_ratio=args.thinking_ratio, device=args.device)
        # Rendered prompts (one per sample) are only used to rebuild messages for the reward model.
        prompts = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True, tools=t) for m, t in zip(messages_batch, tools_batch)]
        packed_samples = []
        for p, r, m, old_lp in zip(prompt_ids_batch, response_ids_batch, response_masks_batch, response_old_logps_batch):
            ids = p + r
            mask = [0] * len(p) + m
            # Log-probs are aligned with ids[1:], hence len(p) - 1 leading placeholders.
            old_logps = [0.0] * max(len(p) - 1, 0) + old_lp
            if len(ids) > args.max_total_len:
                # Over-long samples keep their tail; the head (prompt side) is dropped.
                ids = ids[-args.max_total_len:]
                mask = mask[-args.max_total_len:]
                old_logps = old_logps[-(len(ids) - 1):]
            # Index of the first trainable token, or len(mask) when there is none.
            prompt_len = next((i for i, v in enumerate(mask) if v == 1), len(mask))
            packed_samples.append((ids, mask, prompt_len, old_logps))
        seq_lens = torch.tensor([len(ids) for ids, _, _, _ in packed_samples], device=args.device)
        max_len = seq_lens.max().item()
        # Right-pad every field to the longest sequence in the batch.
        input_ids = torch.tensor([ids + [tokenizer.pad_token_id] * (max_len - len(ids)) for ids, _, _, _ in packed_samples], device=args.device)
        prompt_lens = torch.tensor([prompt_len for _, _, prompt_len, _ in packed_samples], device=args.device)
        full_response_masks = torch.tensor([mask + [0] * (max_len - len(mask)) for _, mask, _, _ in packed_samples], device=args.device, dtype=torch.float32)
        old_per_token_logps = torch.tensor([old_logps + [0.0] * ((max_len - 1) - len(old_logps)) for _, _, _, old_logps in packed_samples], device=args.device, dtype=torch.float32)
        # Attention mask: every non-pad token id is attended to.
        full_mask = (input_ids != tokenizer.pad_token_id).long()
        rewards = calculate_rewards(prompts, completions, gt_batch, tools_batch, args.num_generations, reward_model, device=args.device, turn_outputs_batch=turn_outputs_batch, unfinished_batch=unfinished_batch)
        # NOTE(review): the forward pass goes through the unwrapped module, not
        # the DistributedDataParallel wrapper — confirm gradients are still
        # synchronised across ranks in multi-GPU runs.
        model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
        with autocast_ctx:
            res = model_unwrapped(input_ids, attention_mask=full_mask)
            aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
            # Logits at position t score token t + 1, so drop the last position.
            logits = res.logits[:, :-1, :]
            per_token_logps = F.log_softmax(logits, dim=-1).gather(2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
        with torch.no_grad():
            ref_per_token_logps = compute_per_token_logps(ref_model, input_ids, input_ids.size(1) - 1, attention_mask=full_mask)
        # Shift the response mask to align with the next-token log-probs.
        completion_mask = full_response_masks[:, 1:]
        # Cut each row's mask after the first masked-in eos token, if there is one.
        is_eos = (input_ids[:, 1:] == tokenizer.eos_token_id) & completion_mask.bool()
        eos_idx = torch.full((completion_mask.size(0),), completion_mask.size(1) - 1, device=args.device, dtype=torch.long)
        has_eos = is_eos.any(dim=1)
        eos_idx[has_eos] = is_eos.int().argmax(dim=1)[has_eos]
        pos = torch.arange(completion_mask.size(1), device=args.device).unsqueeze(0)
        completion_mask = completion_mask * (pos <= eos_idx.unsqueeze(1)).float()
        token_counts = completion_mask.sum(dim=1)
        # Rows without any trainable token are excluded from the loss mean.
        valid_rows = token_counts > 0
        if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
            for i in range(len(messages_batch)):
                Logger(f"[DEBUG] step={step}, gt[{i}]: {repr(gt_batch[i])}")
                Logger('-'*100)
                for j in range(args.num_generations):
                    idx = i * args.num_generations + j
                    plen, slen = prompt_lens[idx].item(), seq_lens[idx].item()
                    Logger(f"{'=' * 30} [DEBUG] gen[{i}][{j}] CONTEXT_BEGIN {'=' * 30}")
                    Logger(contexts[idx])
                    Logger(f"{'=' * 31} [DEBUG] gen[{i}][{j}] CONTEXT_END {'=' * 31}")
                    Logger(f"[DEBUG] gen[{i}][{j}] prompt_len={plen}, seq_len={slen}")
                    tokens = input_ids[idx, plen:slen].tolist()
                    text = tokenizer.decode(tokens, skip_special_tokens=False)
                    Logger(f"{'=' * 28} [DEBUG] gen[{i}][{j}] COMPLETION_BEGIN [{plen}:{slen}] {'=' * 28}")
                    Logger(text)
                    Logger(f"{'=' * 29} [DEBUG] gen[{i}][{j}] COMPLETION_END {'=' * 29}")
                    Logger(f"[DEBUG] gen[{i}][{j}] reward={rewards[idx].item():.4f}")
                Logger('='*100)
        # Group-relative advantages: normalise rewards within each prompt's group of generations.
        grouped_rewards = rewards.view(-1, args.num_generations)
        mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations)
        std_r = grouped_rewards.std(dim=1, unbiased=False).repeat_interleave(args.num_generations)
        advantages = (rewards - mean_r) / (std_r + 1e-4)
        # Per-token KL penalty exp(d) - d - 1 with d = ref log-prob - policy log-prob (always >= 0).
        kl_div = ref_per_token_logps - per_token_logps
        per_token_kl = torch.exp(kl_div) - kl_div - 1
        # Importance ratio between the current policy and the rollout-time log-probs.
        ratio = torch.exp(per_token_logps - old_per_token_logps)
        if args.loss_type == "cispo":
            # The ratio is clamped from above and detached, so it acts as a weight on the log-prob gradient.
            clamped_ratio = torch.clamp(ratio, max=args.epsilon_high).detach()
            per_token_loss = -(clamped_ratio * advantages.unsqueeze(1) * per_token_logps - args.beta * per_token_kl)
        else:
            # Clipped surrogate: take the smaller of the raw and clipped ratio terms.
            clipped_ratio = torch.clamp(ratio, 1 - args.epsilon, 1 + args.epsilon)
            per_token_loss1 = ratio * advantages.unsqueeze(1)
            per_token_loss2 = clipped_ratio * advantages.unsqueeze(1)
            per_token_loss = -(torch.min(per_token_loss1, per_token_loss2) - args.beta * per_token_kl)
        # Mean over valid rows of each row's token-averaged loss; a zero that
        # stays attached to the graph when no row has trainable tokens.
        policy_loss = (((per_token_loss * completion_mask).sum(dim=1)[valid_rows] / token_counts[valid_rows].clamp(min=1)).mean()
                       if valid_rows.any() else per_token_loss.sum() * 0.0)
        loss = (policy_loss + aux_loss) / args.accumulation_steps
        loss.backward()
        if step % args.accumulation_steps == 0:
            if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step(); scheduler.step(); optimizer.zero_grad()
        if step % args.log_interval == 0 or step == iters:
            # Undo the accumulation scaling so the logged loss is per-step.
            pl = loss.item() * args.accumulation_steps
            ar = rewards.mean().item()
            al = token_counts.float().mean().item()
            kl = ((ref_per_token_logps - per_token_logps) * completion_mask).sum().item() / max(token_counts.sum().item(), 1)
            gs = grouped_rewards.std(dim=1, unbiased=False).mean().item()
            am, ast = advantages.mean().item(), advantages.std().item()
            lr = optimizer.param_groups[0]['lr']
            Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}), Reward:{ar:.4f}, KL:{kl:.4f}, GrpStd:{gs:.4f}, AdvStd:{ast:.4f}, Loss:{pl:.4f}, AvgLen:{al:.2f}, AdvMean:{am:.4f}, LR:{lr:.8f}')
            if wandb and is_main_process():
                wandb.log({"reward":ar,"kl_ref":kl,"group_reward_std":gs,"advantages_std":ast,"policy_loss":pl,"avg_response_len":al,"advantages_mean":am,"learning_rate":lr})
        if (step % args.save_interval == 0 or step == iters) and is_main_process():
            model.eval()
            moe_suffix = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            # Unwrap DDP and, when present, the `_orig_mod` attribute before reading the weights.
            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
            raw_model = getattr(raw_model, '_orig_mod', raw_model)
            state_dict = raw_model.state_dict()
            # Weights are stored in half precision on the CPU.
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
                          epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
            model.train()
            del state_dict
        # NOTE(review): the rollout engine is refreshed only on save steps and
        # the final step — presumably it samples from the last refreshed
        # weights in between; confirm this is intended.
        if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(model)
        del per_token_logps, ref_per_token_logps
        del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
    # Flush gradients left over when the step count is not a multiple of accumulation_steps.
    if last_step > start_step and last_step % args.accumulation_steps != 0:
        if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step(); scheduler.step(); optimizer.zero_grad()
# Script entry point: parse arguments, build the models, rollout engine and
# data pipeline, optionally resume from a checkpoint, then train per epoch.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind Agent RL")
    parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
    parser.add_argument('--save_weight', default='agent', type=str, help="保存权重名称")
    parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
    parser.add_argument("--batch_size", type=int, default=2, help="批次大小")
    parser.add_argument("--learning_rate", type=float, default=3e-7, help="学习率")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="数据类型 bfloat16/float16")
    parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
    parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
    parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
    parser.add_argument('--hidden_size', default=768, type=int, help="模型隐藏层维度")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="模型层数")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE")
    parser.add_argument('--max_seq_len', default=1024, type=int, help="最大序列长度")
    parser.add_argument("--max_gen_len", type=int, default=768, help="单次最大生成长度")
    parser.add_argument("--max_total_len", type=int, default=2500, help="训练侧最终总长度上界")
    parser.add_argument("--data_path", type=str, default="../dataset/agent_rl.jsonl", help="训练数据路径")
    parser.add_argument("--num_generations", type=int, default=4, help="每个prompt生成数量")
    parser.add_argument("--beta", type=float, default=0.1, help="KL散度惩罚系数")
    parser.add_argument("--loss_type", type=str, default="cispo", choices=["grpo", "cispo"], help="loss类型")
    parser.add_argument("--epsilon", type=float, default=0.2, help="GRPO的PPO clip epsilon")
    parser.add_argument("--epsilon_high", type=float, default=5.0, help="epsilon上界")
    parser.add_argument('--from_weight', default='full_sft', type=str, help="加载预训练权重名称")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否从checkpoint恢复")
    parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb记录")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Agent-RL", help="wandb项目名称")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile")
    parser.add_argument("--debug_mode", action="store_true", help="调试模式")
    parser.add_argument("--debug_interval", type=int, default=20, help="调试日志间隔")
    parser.add_argument("--thinking_ratio", type=float, default=0.1, help="按概率开启thinking0.0~1.0")
    parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
    parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型")
    parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL")
    parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径")
    parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_agent", help="SGLang共享存储路径")
    args = parser.parse_args()
    local_rank = init_distributed_mode()
    # In distributed runs each process uses the GPU matching its local rank.
    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
    # Rank-dependent seed so processes do not share one random stream.
    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
    os.makedirs(args.save_dir, exist_ok=True)
    # The model's maximum length covers the prompt budget plus one generation budget.
    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
                               max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
    ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
    device_type = "cuda" if "cuda" in args.device else "cpu"
    # Any --dtype other than "bfloat16" falls back to float16.
    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
    # Mixed precision is used on CUDA only.
    autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
    wandb = None
    if args.use_wandb and is_main_process():
        # swanlab is imported under the name `wandb` and used through that alias.
        import swanlab as wandb
        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
        resume = 'must' if wandb_id else None
        wandb.init(project=args.wandb_project, name=f"Agent-RL-E{args.epochs}-B{args.batch_size}-LR{args.learning_rate}", id=wandb_id, resume=resume)
    # Policy and frozen reference model both start from the same weights.
    model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
    ref_model, _ = init_model(lm_config, args.from_weight, device=args.device)
    ref_model = ref_model.eval().requires_grad_(False)
    reward_model = LMForRewardModel(args.reward_model_path, device=args.device, dtype=torch.float16)
    Logger(f'Loaded reward model from {args.reward_model_path}')
    # Rollout engine
    rollout_engine = create_rollout_engine(
        engine_type=args.rollout_engine,
        policy_model=model,
        tokenizer=tokenizer,
        device=args.device,
        autocast_ctx=autocast_ctx,
        sglang_base_url=args.sglang_base_url,
        sglang_model_path=args.sglang_model_path,
        sglang_shared_path=args.sglang_shared_path,
    )
    train_ds = AgentRLDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    # Batches stay as plain Python lists; nothing is tensorised at load time.
    def collate_fn(batch): return {'messages': [b['messages'] for b in batch], 'tools': [b['tools'] for b in batch], 'gt': [b['gt'] for b in batch]}
    # This loader is used only to count the batches per epoch for the LR schedule.
    loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler, collate_fn=collate_fn)
    iters = len(loader_for_count)
    total_optimizer_steps = math.ceil(iters / args.accumulation_steps) * args.epochs
    # Cosine decay from the base LR down to one tenth of it.
    scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
    start_epoch, start_step = 0, 0
    if ckp_data:
        # Restore model, optimizer and scheduler state plus the training position.
        model.load_state_dict(ckp_data['model'])
        optimizer.load_state_dict(ckp_data['optimizer'])
        scheduler.load_state_dict(ckp_data['scheduler'])
        start_epoch = ckp_data['epoch']
        start_step = ckp_data.get('step', 0)
    if args.use_compile == 1:
        model = torch.compile(model)
        Logger('torch.compile enabled')
        # Hand the re-wrapped model to the rollout engine.
        rollout_engine.update_policy(model)
    if dist.is_initialized():
        model = DistributedDataParallel(model, device_ids=[local_rank])
        rollout_engine.update_policy(model)
    for epoch in range(start_epoch, args.epochs):
        train_sampler and train_sampler.set_epoch(epoch)
        # NOTE(review): this reseeds with an epoch-only value, so the
        # rank-dependent seed set at start-up no longer applies afterwards —
        # confirm that is intended for multi-GPU runs.
        setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
        # Only the first epoch after a resume skips already-processed steps.
        skip = start_step if (epoch == start_epoch and start_step > 0) else 0
        batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
        loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True, collate_fn=collate_fn)
        if skip > 0:
            Logger(f'Epoch [{epoch+1}/{args.epochs}]: skip {start_step} steps')
            rl_train_epoch(epoch, loader, len(loader) + skip, rollout_engine, ref_model, reward_model, start_step, wandb, use_sglang = (args.rollout_engine == "sglang"))
        else:
            rl_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()