Some minor changes

1. Added comments to the model-definition and model-training code and reorganized code placement to improve readability and comprehensibility.
2. Fixed the issue where training could not resume correctly from a distributed environment after an interruption.
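
The resume fix boils down to an ordering change: the checkpoint is now restored after the DistributedDataParallel wrap, and weights are loaded into model.module. A minimal sketch of the resulting flow (names follow the trainers below; checkpoints are assumed to hold unwrapped-module keys):

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

def wrap_then_resume(model, optimizer, ckp_data, local_rank):
    # Wrap for DDP first, then restore state.
    if dist.is_initialized():
        model = DistributedDataParallel(model, device_ids=[local_rank])
    start_epoch, start_step = 0, 0
    if ckp_data:
        # Load into the underlying module: DDP prefixes every parameter key
        # with "module.", so model.load_state_dict on the wrapper would not
        # match a checkpoint saved from the unwrapped model.
        target = model.module if isinstance(model, DistributedDataParallel) else model
        target.load_state_dict(ckp_data['model'])
        optimizer.load_state_dict(ckp_data['optimizer'])
        start_epoch, start_step = ckp_data['epoch'], ckp_data.get('step', 0)
    return model, start_epoch, start_step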
This commit is contained in:
翟锦洋 2026-04-12 18:06:52 +08:00
parent 88e675dc2c
commit ef40a1f271
11 changed files with 376 additions and 266 deletions

View File

@ -11,21 +11,17 @@ class MiniMindConfig(PretrainedConfig):
model_type = "minimind"
def __init__(self, hidden_size=768, num_hidden_layers=8, use_moe=False, **kwargs):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.use_moe = use_moe
self.dropout = kwargs.get("dropout", 0.0)
####################################################
# Token-related
####################################################
self.vocab_size = kwargs.get("vocab_size", 6400)
self.bos_token_id = kwargs.get("bos_token_id", 1)
self.eos_token_id = kwargs.get("eos_token_id", 2)
self.flash_attn = kwargs.get("flash_attn", True)
self.num_attention_heads = kwargs.get("num_attention_heads", 8)
self.num_key_value_heads = kwargs.get("num_key_value_heads", 4)
self.head_dim = kwargs.get("head_dim", self.hidden_size // self.num_attention_heads)
self.hidden_act = kwargs.get("hidden_act", 'silu')
self.intermediate_size = kwargs.get("intermediate_size", math.ceil(hidden_size * math.pi / 64) * 64)
####################################################
# Embedding-related
####################################################
self.max_position_embeddings = kwargs.get("max_position_embeddings", 32768)
self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6)
self.rope_theta = kwargs.get("rope_theta", 1e6)
self.inference_rope_scaling = kwargs.get("inference_rope_scaling", False)
self.rope_scaling = {
@ -36,6 +32,35 @@ class MiniMindConfig(PretrainedConfig):
"attention_factor": 1.0,
"type": "yarn"
} if self.inference_rope_scaling else None
####################################################
# Representation-space related
####################################################
self.hidden_size = hidden_size
####################################################
# Transformer-related
####################################################
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = kwargs.get("num_attention_heads", 8)
# KV-head reuse mechanism in GQA
self.num_key_value_heads = kwargs.get("num_key_value_heads", 4)
self.head_dim = kwargs.get("head_dim", self.hidden_size // self.num_attention_heads)
self.flash_attn = kwargs.get("flash_attn", True)
####################################################
# Feed-forward network related
####################################################
self.intermediate_size = kwargs.get("intermediate_size", math.ceil(hidden_size * math.pi / 64) * 64)
self.hidden_act = kwargs.get("hidden_act", 'silu')
####################################################
# Overall model architecture related
####################################################
self.use_moe = use_moe
self.dropout = kwargs.get("dropout", 0.0)
self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6)
### MoE specific configs (ignored if use_moe = False)
self.num_experts = kwargs.get("num_experts", 4)
self.num_experts_per_tok = kwargs.get("num_experts_per_tok", 1)
@ -46,17 +71,6 @@ class MiniMindConfig(PretrainedConfig):
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
# MiniMind Model
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return (self.weight * self.norm(x.float())).type_as(x)
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: dict = None):
freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
@ -87,6 +101,18 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
if n_rep == 1: return x
return (x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim))
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return (self.weight * self.norm(x.float())).type_as(x)
class Attention(nn.Module):
def __init__(self, config: MiniMindConfig):
super().__init__()
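
For reference, a standalone sketch of how the derived defaults above work out (the values follow directly from the constructor arithmetic):

import math

hidden_size = 768
# intermediate_size defaults to hidden_size * pi, rounded up to a multiple of 64:
print(math.ceil(hidden_size * math.pi / 64) * 64)  # 2432, since 768 * pi ~= 2412.7
# head_dim defaults to hidden_size // num_attention_heads:
print(hidden_size // 8)  # 96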

View File

@ -126,10 +126,12 @@ def convert_json_to_jinja(json_file_path, output_path):
if __name__ == '__main__':
lm_config = MiniMindConfig(hidden_size=768, num_hidden_layers=8, max_seq_len=8192, use_moe=True)
# Note the use_moe setting here; the non-MoE model is used by default
lm_config = MiniMindConfig(hidden_size=768, num_hidden_layers=8, max_seq_len=8192, use_moe=False)
# convert torch to transformers
torch_path = f"../out/full_sft_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
transformers_path = '../minimind-3-moe'
transformers_path = '../minimind-3'
convert_torch2transformers(torch_path, transformers_path)
# # merge lora

View File

@ -241,15 +241,19 @@ def calculate_rewards(prompts, completions, gt_batch, tools_batch, num_gen, rewa
def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model=None, start_step=0, wandb=None, use_sglang=False):
last_step = start_step
for step, batch in enumerate(loader, start=start_step + 1):
########################### Pre-step operations ###########################
# Data preparation
messages_batch = batch['messages']
tools_batch = batch['tools']
gt_batch = batch['gt']
last_step = step
prompts = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True, tools=t) for m, t in zip(messages_batch, tools_batch)]
with torch.no_grad():
completions, contexts, prompt_ids_batch, response_ids_batch, response_masks_batch, response_old_logps_batch, turn_outputs_batch, unfinished_batch = rollout_batch(rollout_engine, tokenizer, messages_batch, tools_batch, args.num_generations, max_turns=3, max_new_tokens=args.max_gen_len, thinking_ratio=args.thinking_ratio, device=args.device)
prompts = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True, tools=t) for m, t in zip(messages_batch, tools_batch)]
# Data processing (ensure consistent sequence lengths)
packed_samples = []
for p, r, m, old_lp in zip(prompt_ids_batch, response_ids_batch, response_masks_batch, response_old_logps_batch):
ids = p + r
@ -268,6 +272,8 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
full_response_masks = torch.tensor([mask + [0] * (max_len - len(mask)) for _, mask, _, _ in packed_samples], device=args.device, dtype=torch.float32)
old_per_token_logps = torch.tensor([old_logps + [0.0] * ((max_len - 1) - len(old_logps)) for _, _, _, old_logps in packed_samples], device=args.device, dtype=torch.float32)
########################### Training-step operations ###########################
# Forward computation
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
with autocast_ctx:
res = model_unwrapped(input_ids)
@ -316,6 +322,8 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
kl_div = ref_per_token_logps - per_token_logps
per_token_kl = torch.exp(kl_div) - kl_div - 1
ratio = torch.exp(per_token_logps - old_per_token_logps)
# Define the loss function
if args.loss_type == "cispo":
clamped_ratio = torch.clamp(ratio, max=args.epsilon_high).detach()
per_token_loss = -(clamped_ratio * advantages.unsqueeze(1) * per_token_logps - args.beta * per_token_kl)
@ -327,13 +335,18 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
policy_loss = (((per_token_loss * completion_mask).sum(dim=1)[valid_rows] / token_counts[valid_rows].clamp(min=1)).mean()
if valid_rows.any() else per_token_loss.sum() * 0.0)
loss = (policy_loss + aux_loss) / args.accumulation_steps
# Backward pass
loss.backward()
# Gradient update
if step % args.accumulation_steps == 0:
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step(); scheduler.step(); optimizer.zero_grad()
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model)
########################### Post-step operations ###########################
# Logging
if step % args.log_interval == 0 or step == iters:
pl = loss.item() * args.accumulation_steps
ar = rewards.mean().item()
@ -346,6 +359,7 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
if wandb and is_main_process():
wandb.log({"reward":ar,"kl_ref":kl,"group_reward_std":gs,"advantages_std":ast,"policy_loss":pl,"avg_response_len":al,"advantages_mean":am,"learning_rate":lr})
# Save the model
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
@ -409,25 +423,20 @@ if __name__ == "__main__":
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_agent", help="SGLang共享存储路径")
args = parser.parse_args()
# ========== 1. Set up the training environment ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Model setup ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb.init(project=args.wandb_project, name=f"Agent-RL-E{args.epochs}-B{args.batch_size}-LR{args.learning_rate}", id=wandb_id, resume=resume)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
@ -456,14 +465,6 @@ if __name__ == "__main__":
total_optimizer_steps = math.ceil(iters / args.accumulation_steps) * args.epochs
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scheduler.load_state_dict(ckp_data['scheduler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
@ -472,6 +473,25 @@ if __name__ == "__main__":
model = DistributedDataParallel(model, device_ids=[local_rank])
if is_main_process(): rollout_engine.update_policy(model)
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
# ========== 3. Checkpoint handling ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb.init(project=args.wandb_project, name=f"Agent-RL-E{args.epochs}-B{args.batch_size}-LR{args.learning_rate}", id=wandb_id, resume=resume)
start_epoch, start_step = 0, 0
if ckp_data:
model.module.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scheduler.load_state_dict(ckp_data['scheduler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 4. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
@ -484,4 +504,5 @@ if __name__ == "__main__":
else:
rl_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
# ========== 5. Clean up distributed processes ==========
if dist.is_initialized(): dist.destroy_process_group()
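
For reference, a toy recomputation of the per-token loss above with made-up tensors (beta and epsilon_high stand in for args.beta and args.epsilon_high):

import torch

beta, epsilon_high = 0.04, 4.0
advantages = torch.tensor([0.5])                    # one sample
per_token_logps = torch.tensor([[-1.2, -0.8]])      # current policy
old_per_token_logps = torch.tensor([[-1.0, -1.0]])  # rollout policy
ref_per_token_logps = torch.tensor([[-1.1, -0.9]])  # frozen reference
kl_div = ref_per_token_logps - per_token_logps
per_token_kl = torch.exp(kl_div) - kl_div - 1       # k3 KL estimator, always >= 0
ratio = torch.exp(per_token_logps - old_per_token_logps)
clamped_ratio = torch.clamp(ratio, max=epsilon_high).detach()  # the "cispo" branch
per_token_loss = -(clamped_ratio * advantages.unsqueeze(1) * per_token_logps - beta * per_token_kl)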

View File

@ -215,7 +215,7 @@ if __name__ == "__main__":
# ========== 6. Restore state from the checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
model.module.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']

View File

@ -54,6 +54,8 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
last_step = start_step
for step, batch in enumerate(loader, start=start_step + 1):
########################### Pre-step operations ###########################
# Data loading
last_step = step
x_chosen = batch['x_chosen'].to(args.device)
x_rejected = batch['x_rejected'].to(args.device)
@ -64,11 +66,14 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
x = torch.cat([x_chosen, x_rejected], dim=0)
y = torch.cat([y_chosen, y_rejected], dim=0)
mask = torch.cat([mask_chosen, mask_rejected], dim=0)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
# Learning-rate adjustment
for param_group in optimizer.param_groups:
param_group['lr'] = lr
########################### Training-step operations ###########################
# Model forward pass
with autocast_ctx:
with torch.no_grad():
ref_outputs = ref_model(x)
@ -83,8 +88,10 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
loss = dpo_loss_val + outputs.aux_loss
loss = loss / args.accumulation_steps
# Model backward pass
scaler.scale(loss).backward()
# Gradient update
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@ -92,6 +99,8 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
scaler.update()
optimizer.zero_grad(set_to_none=True)
########################### Post-step operations ###########################
# Logging
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
@ -104,6 +113,7 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
if wandb: wandb.log({"loss": current_loss, "dpo_loss": current_dpo_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
# Save the model
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
@ -154,31 +164,20 @@ if __name__ == "__main__":
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
args = parser.parse_args()
# ========== 1. Initialize the environment and random seed ==========
# ========== 1. Set up the training environment ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Configure directories, model parameters, and check for a checkpoint ==========
# ========== 2. Model setup ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Configure wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Define the model and the reference model ==========
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
Logger(f'Total policy model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
# Initialize the reference model (ref_model, frozen)
@ -191,25 +190,34 @@ if __name__ == "__main__":
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# ========== 6. Restore state from the checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. Compilation and distributed wrapping ==========
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 8. Start training ==========
# ========== 3. Checkpoint handling ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
start_epoch, start_step = 0, 0
if ckp_data:
model.module.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 4. Training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
@ -222,5 +230,5 @@ if __name__ == "__main__":
else:
train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
# ========== 9. Clean up distributed processes ==========
# ========== 5. Clean up distributed processes ==========
if dist.is_initialized(): dist.destroy_process_group()
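
For context, the standard DPO objective that a dpo_loss implementation like the one called above typically computes; a generic sketch, not necessarily this repo's exact dpo_loss:

import torch.nn.functional as F

def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                    ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # -log(sigmoid(beta * ((log pi_c - log ref_c) - (log pi_r - log ref_r))))
    chosen_logratio = policy_chosen_logps - ref_chosen_logps
    rejected_logratio = policy_rejected_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()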

View File

@ -19,26 +19,32 @@ from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint
warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
start_time = time.time()
last_step = start_step
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
########################### Pre-step operations ###########################
# Data loading
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
last_step = step
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
# Learning-rate adjustment
for param_group in optimizer.param_groups:
param_group['lr'] = lr
########################### Training-step operations ###########################
# Model forward pass
with autocast_ctx:
res = model(input_ids, labels=labels)
loss = res.loss + res.aux_loss
loss = loss / args.accumulation_steps
# Model backward pass
scaler.scale(loss).backward()
if step % args.accumulation_steps == 0:
# Gradient update
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@ -47,17 +53,20 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters:
########################### Post-step operations ###########################
# Logging
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
current_logits_loss = current_loss - current_aux_loss
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters) and is_main_process():
# Save the model
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
@ -65,20 +74,12 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
del state_dict
del input_ids, labels, res, loss
if last_step > start_step and last_step % args.accumulation_steps != 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Full SFT")
@ -86,7 +87,7 @@ if __name__ == "__main__":
parser.add_argument('--save_weight', default='full_sft', type=str, help="prefix for saved weight files")
parser.add_argument("--epochs", type=int, default=2, help="number of training epochs")
parser.add_argument("--batch_size", type=int, default=16, help="batch size")
parser.add_argument("--learning_rate", type=float, default=1e-5, help="initial learning rate")
parser.add_argument("--learning_rate", type=float, default=1e-6, help="initial learning rate")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
parser.add_argument("--dtype", type=str, default="bfloat16", help="mixed-precision dtype")
parser.add_argument("--num_workers", type=int, default=8, help="number of data-loading workers")
@ -96,7 +97,7 @@ if __name__ == "__main__":
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=768, type=int, help="训练的最大截断长度中文1token≈1.5~1.7字符)")
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度中文1token≈1.5~1.7字符)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/sft_t2t_mini.jsonl", help="训练数据路径")
parser.add_argument('--from_weight', default='pretrain', type=str, help="基于哪个权重训练为none则不基于任何权重训练")
@ -106,55 +107,52 @@ if __name__ == "__main__":
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
args = parser.parse_args()
# ========== 1. Initialize the environment and random seed ==========
# ========== 1. Set up the training environment ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
if dist.is_initialized(): args.device = f"cuda:{local_rank}" # set the device for distributed training
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) # set the random seed
# ========== 2. Configure directories, model parameters, and check for a checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
# ========== 2. Model setup ==========
os.makedirs(args.save_dir, exist_ok=True) # create the model save directory
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Configure wandb ==========
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) # set up mixed precision
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) # all model-related initialization
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank]) # model initialization for distributed training
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None # check for a checkpoint
# ========== 3. Checkpoint handling ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Define the model, data, and optimizer ==========
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# ========== 6. Restore state from the checkpoint ==========
wandb_run_name = f"MiniMind-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) # training visualization resumed via the checkpoint
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
model.module.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
start_step = ckp_data.get('step', 0) # state restored from the checkpoint
# ========== 7. Compilation and distributed wrapping ==========
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. Start training ==========
# ========== 4. Training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
@ -167,5 +165,5 @@ if __name__ == "__main__":
else:
train_epoch(epoch, loader, len(loader), 0, wandb)
# ========== 9. Clean up distributed processes ==========
if dist.is_initialized(): dist.destroy_process_group()
# ========== 5. Tear down the training environment ==========
if dist.is_initialized(): dist.destroy_process_group() # tear down the distributed training environment
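
A quick worked check of the new ETA formula in the logging above, with assumed numbers:

spend_time, step, iters = 120.0, 9, 100  # 120 s elapsed after the 10th step
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
print(eta_min)  # 18.0: projected 20 min total minus the 2 min already spent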

View File

@ -299,7 +299,7 @@ if __name__ == "__main__":
# ========== 6. Restore state from the checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
model.module.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scheduler.load_state_dict(ckp_data['scheduler'])
start_epoch = ckp_data['epoch']

View File

@ -25,20 +25,28 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
start_time = time.time()
last_step = start_step
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
########################### Pre-step operations ###########################
# Data loading
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
last_step = step
# Learning-rate adjustment
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
########################### Training-step operations ###########################
# Model forward pass
with autocast_ctx:
res = model(input_ids, labels=labels)
loss = res.loss + res.aux_loss
loss = loss / args.accumulation_steps
# Model backward pass
scaler.scale(loss).backward()
# Gradient update
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
@ -46,6 +54,8 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
scaler.update()
optimizer.zero_grad(set_to_none=True)
########################### Post-step operations ###########################
# Logging
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
@ -56,10 +66,10 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
# Save the model (LoRA weights only)
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}.pth'
# With LoRA, save only the LoRA weights
save_lora(model, lora_save_path)
lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
@ -99,41 +109,27 @@ if __name__ == "__main__":
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
args = parser.parse_args()
# ========== 1. Initialize the environment and random seed ==========
# ========== 1. Set up the training environment ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Configure directories, model parameters, and check for a checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
# ========== 2. Model setup ==========
os.makedirs(args.save_dir, exist_ok=True)
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Configure wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Define the model, apply LoRA, and freeze non-LoRA parameters ==========
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
apply_lora(model)
# Parameter statistics
total_params = sum(p.numel() for p in model.parameters())
lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)
Logger(f"LLM 总参数量: {total_params / 1e6:.3f} M")
Logger(f"LoRA 参数量: {lora_params_count / 1e6:.3f} M")
Logger(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")
# 冻结非LoRA参数收集LoRA参数
lora_params = []
for name, param in model.named_parameters():
@ -143,13 +139,29 @@ if __name__ == "__main__":
else:
param.requires_grad = False
# ========== 6. Define the data and optimizer ==========
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Checkpoint handling ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 7. Restore state from the checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'], strict=False)
@ -158,15 +170,7 @@ if __name__ == "__main__":
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 8. Compilation and distributed wrapping ==========
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 9. Start training ==========
# ========== 4. Training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
@ -179,5 +183,5 @@ if __name__ == "__main__":
else:
train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)
# ========== 10. Clean up distributed processes ==========
# ========== 5. Tear down the training environment ==========
if dist.is_initialized(): dist.destroy_process_group()
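
For background, a generic sketch of the low-rank adapter idea behind apply_lora (illustrative only; the class name and hyperparameters here are assumptions, not this repo's code). Naming the adapter tensors with a 'lora' prefix keeps them visible to the freeze/collect loop above:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    # y = W x + (alpha / r) * B A x, with W frozen and only A, B trainable.
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        self.lora_a = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, r))
        self.scale = alpha / r
        base.weight.requires_grad_(False)
        if base.bias is not None:
            base.bias.requires_grad_(False)

    def forward(self, x):
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scale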

View File

@ -25,28 +25,24 @@ from trainer.rollout_engine import create_rollout_engine
warnings.filterwarnings('ignore')
def rep_penalty(text, n=3, cap=0.5):
toks = re.findall(r"\w+|[^\w\s]", text.lower())
grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)]
return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0
# Custom critic model, inheriting from MiniMindLM
class CriticModel(MiniMindForCausalLM):
def __init__(self, params):
super().__init__(params)
# Replace lm_head with a linear layer that outputs a single value
self.value_head = nn.Linear(params.hidden_size, 1)
def forward(self, input_ids=None, attention_mask=None, **kwargs):
# Use the base model to get the hidden states
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
hidden_states = self.model.norm(outputs[0])
# Use value_head to get the value estimate
# Use value_head in place of lm_head to get the value estimate
values = self.value_head(hidden_states).squeeze(-1)
return values
def rep_penalty(text, n=3, cap=0.5):
toks = re.findall(r"\w+|[^\w\s]", text.lower())
grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)]
return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0
def calculate_rewards(prompts, responses, reward_model):
rewards = torch.zeros(len(responses), device=args.device)
@ -81,6 +77,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
grad_accum_step = 0
for step, batch in enumerate(loader, start=start_step + 1):
########################### Pre-step operations ###########################
# Data preparation
prompts = batch["prompt"] # list[str], length B
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_seq_len,
padding_side="left").to(args.device) # input_ids: [B, P], attention_mask: [B, P]
@ -110,7 +108,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
Logger(f"{'=' * 29} [DEBUG] sample[{i}] RESPONSE_END {'=' * 29}")
Logger(f"[DEBUG] reward={rewards[i].item():.4f}")
Logger('='*100)
# Data processing
full_mask = (gen_out != tokenizer.pad_token_id).long() # [B, P+R]
labels = gen_out[:, 1:].clone() # [B, P+R-1]
seq_len, resp_start = gen_out.size(1) - 1, prompt_length - 1
@ -126,6 +125,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
resp_policy_mask = ((resp_idx < resp_lengths.unsqueeze(1)) & resp_pad_mask).float()
resp_value_mask = resp_policy_mask.clone()
########################### Training-step operations ###########################
# Data computation: initialize the advantages and the value targets (returns)
with torch.no_grad(): # the rollout phase only needs inference to get old_logp and old_values; cutting gradients saves memory
critic_for_rollout = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model
values_seq = critic_for_rollout(input_ids=gen_out, attention_mask=full_mask)
@ -155,7 +156,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
adv_mean = (advantages * resp_policy_mask).sum() / resp_policy_mask.sum().clamp(min=1)
adv_var = ((advantages - adv_mean) ** 2 * resp_policy_mask).sum() / resp_policy_mask.sum().clamp(min=1)
advantages = (advantages - adv_mean) * torch.rsqrt(adv_var + 1e-8) * resp_policy_mask
# Initialize training parameters
mb_size = max(1, min(args.mini_batch_size, B))
stop_ppo = False
policy_loss_sum = 0.0
@ -167,6 +169,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
log_count = 0
actor_unwrapped = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
critic_unwrapped = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model
# Reinforcement-learning training iterations
for ppo_epoch in range(args.ppo_update_iters):
if stop_ppo:
break
@ -240,6 +244,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
actor_optimizer.zero_grad()
critic_optimizer.zero_grad()
# Gradient update
if grad_accum_step % args.accumulation_steps != 0:
clip_grad_norm_(actor_model.parameters(), args.grad_clip)
clip_grad_norm_(critic_model.parameters(), args.grad_clip)
@ -249,9 +254,12 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
critic_scheduler.step()
actor_optimizer.zero_grad()
critic_optimizer.zero_grad()
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(actor_model)
########################### Post-step operations ###########################
# Model update
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(actor_model)
# Logging
if is_main_process():
critic_loss_val = value_loss_sum / max(log_count, 1)
reward_val = rewards.mean().item()
@ -278,6 +286,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
f"ClipFrac: {clipfrac_val:.4f}, Critic Loss: {critic_loss_val:.4f}, "
f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.8f}, Critic LR: {critic_lr:.8f}")
# Save the model
if (step % args.save_interval == 0 or step == iters) and is_main_process():
actor_model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
@ -345,43 +354,37 @@ if __name__ == "__main__":
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_ppo", help="SGLang共享存储路径")
args = parser.parse_args()
# ========== 1. Initialize the environment and random seed ==========
# ========== 1. Set up the training environment ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Configure directories, model parameters, and check for a checkpoint ==========
# ========== 2. Model setup ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Configure wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Initialize the models and data ==========
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
base_weight = args.from_weight
# Actor model
# Initialize the four models of LLM PPO (actor, reference, critic, reward)
actor_model, tokenizer = init_model(lm_config, base_weight, device=args.device)
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
ref_model = ref_model.eval().requires_grad_(False)
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = torch.load(ckp, map_location=args.device)
critic_model = CriticModel(lm_config)
critic_model.load_state_dict(state_dict, strict=False)
critic_model = critic_model.to(args.device)
reward_model = LMForRewardModel(args.reward_model_path, device=args.device, dtype=torch.float16)
# Rollout engine
rollout_engine = create_rollout_engine(
engine_type=args.rollout_engine,
@ -393,6 +396,7 @@ if __name__ == "__main__":
sglang_model_path=args.sglang_model_path,
sglang_shared_path=args.sglang_shared_path,
)
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len), thinking_ratio=args.thinking_ratio)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
actor_optimizer = optim.AdamW(actor_model.parameters(), lr=args.learning_rate)
@ -404,18 +408,6 @@ if __name__ == "__main__":
actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, eta_min=args.critic_learning_rate / 10)
start_epoch, start_step = 0, 0
if ckp_data:
actor_model.load_state_dict(ckp_data['model'])
critic_model.load_state_dict(ckp_data['critic_model'])
actor_optimizer.load_state_dict(ckp_data['optimizer'])
critic_optimizer.load_state_dict(ckp_data['critic_optimizer'])
actor_scheduler.load_state_dict(ckp_data['scheduler'])
critic_scheduler.load_state_dict(ckp_data['critic_scheduler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. Compilation and distributed wrapping ==========
if args.use_compile == 1:
actor_model = torch.compile(actor_model)
Logger('torch.compile enabled')
@ -426,8 +418,30 @@ if __name__ == "__main__":
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
if is_main_process(): rollout_engine.update_policy(actor_model)
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Checkpoint handling ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
start_epoch, start_step = 0, 0
if ckp_data:
actor_model.module.load_state_dict(ckp_data['model'])
critic_model.module.load_state_dict(ckp_data['critic_model'])
actor_optimizer.load_state_dict(ckp_data['optimizer'])
critic_optimizer.load_state_dict(ckp_data['critic_optimizer'])
actor_scheduler.load_state_dict(ckp_data['scheduler'])
critic_scheduler.load_state_dict(ckp_data['critic_scheduler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 8. Start training ==========
# ========== 4. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
@ -440,5 +454,5 @@ if __name__ == "__main__":
else:
ppo_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, actor_scheduler, critic_scheduler, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
# ========== 9. Clean up distributed processes ==========
# ========== 5. Clean up distributed processes ==========
if dist.is_initialized(): dist.destroy_process_group()
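
For reference, a toy recomputation of the masked advantage normalization used above (made-up tensors; the mask zeroes out padding positions):

import torch

advantages = torch.tensor([[0.5, -0.3, 0.0], [1.0, 0.2, 0.0]])
resp_policy_mask = torch.tensor([[1., 1., 0.], [1., 1., 0.]])
adv_mean = (advantages * resp_policy_mask).sum() / resp_policy_mask.sum().clamp(min=1)
adv_var = ((advantages - adv_mean) ** 2 * resp_policy_mask).sum() / resp_policy_mask.sum().clamp(min=1)
advantages = (advantages - adv_mean) * torch.rsqrt(adv_var + 1e-8) * resp_policy_mask
# Valid positions now have roughly zero mean and unit variance; padding stays zero.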

View File

@ -22,23 +22,30 @@ warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
start_time = time.time()
last_step = start_step
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
########################### Pre-step operations ###########################
# Data loading
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
last_step = step
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
# Learning-rate adjustment
for param_group in optimizer.param_groups:
param_group['lr'] = lr
########################### Training-step operations ###########################
# Model forward pass
with autocast_ctx:
res = model(input_ids, labels=labels)
loss = res.loss + res.aux_loss
loss = loss / args.accumulation_steps
# Model backward pass
scaler.scale(loss).backward()
if step % args.accumulation_steps == 0:
# Gradient update
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@ -47,17 +54,20 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters:
########################### Post-step operations ###########################
# Logging
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
current_logits_loss = current_loss - current_aux_loss
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters) and is_main_process():
# Save the model
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
@ -71,19 +81,12 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
del input_ids, labels, res, loss
if last_step > start_step and last_step % args.accumulation_steps != 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
parser.add_argument("--epochs", type=int, default=1, help="训练轮数建议1轮zero或2-6轮充分训练")
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
@ -105,55 +108,52 @@ if __name__ == "__main__":
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
args = parser.parse_args()
# ========== 1. Initialize the environment and random seed ==========
# ========== 1. Set up the training environment ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
if dist.is_initialized(): args.device = f"cuda:{local_rank}" # set the device for distributed training
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) # set the random seed
# ========== 2. Configure directories, model parameters, and check for a checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
# ========== 2. Model setup ==========
os.makedirs(args.save_dir, exist_ok=True) # create the model save directory
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Configure wandb ==========
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) # set up mixed precision
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) # all model-related initialization
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank]) # model initialization for distributed training
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None # check for a checkpoint
# ========== 3. Checkpoint handling ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Define the model, data, and optimizer ==========
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# ========== 6. Restore state from the checkpoint ==========
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) # training visualization resumed via the checkpoint
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
model.module.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
start_step = ckp_data.get('step', 0) # state restored from the checkpoint
# ========== 7. Compilation and distributed wrapping ==========
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. Start training ==========
# ========== 4. Training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
@ -166,5 +166,5 @@ if __name__ == "__main__":
else:
train_epoch(epoch, loader, len(loader), 0, wandb)
# ========== 9. Clean up distributed processes ==========
if dist.is_initialized(): dist.destroy_process_group()
# ========== 5. Tear down the training environment ==========
if dist.is_initialized(): dist.destroy_process_group() # tear down the distributed training environment
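
One small worked illustration of the accumulation trigger changed above (assuming accumulation_steps = 4 and enumerate starting at 1):

steps = range(1, 13)
print([s for s in steps if s % 4 == 0])        # old trigger: [4, 8, 12]
print([s for s in steps if (s + 1) % 4 == 0])  # new trigger: [3, 7, 11]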

View File

@ -9,22 +9,32 @@ TOKENIZER_DIR = '../model_learn_tokenizer/'
VOCAB_SIZE = 6400
SPECIAL_TOKENS_NUM = 36
# Fetch texts (helper function for train_tokenizer)
def get_texts(data_path):
with open(data_path, 'r', encoding='utf-8', errors='ignore') as f:
for i, line in enumerate(f):
if i >= 10000: break # sample 10,000 lines for testing
if i >= 10000: break # sample 10,000 lines for testing (comment this line out to read the entire dataset)
try:
data = json.loads(line)
# Use only the content field of the conversations entries in the SFT dataset; ignore the reasoning_content, tools, and tool_calls fields
contents = [item.get('content') for item in data.get('conversations', []) if item.get('content')]
if contents:
# Note the use of yield here
yield "\n".join(contents)
except json.JSONDecodeError:
continue
def train_tokenizer(data_path, tokenizer_dir, vocab_size, special_tokens_num=SPECIAL_TOKENS_NUM):
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
########################### Pre-training operations ###########################
# Fetch the texts
texts = get_texts(data_path)
# Create the tokenizer directory
os.makedirs(tokenizer_dir, exist_ok=True)
tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json")
tokenizer_config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
# Definition of special_tokens
special_tokens_list = [
"<|endoftext|>", "<|im_start|>", "<|im_end|>",
"<|object_ref_start|>", "<|object_ref_end|>", "<|box_start|>", "<|box_end|>", "<|quad_start|>", "<|quad_end|>",
@ -40,21 +50,28 @@ def train_tokenizer(data_path, tokenizer_dir, vocab_size, special_tokens_num=SPE
num_buffer = special_tokens_num - len(special_tokens_list + additional_tokens_list)
buffer_tokens = [f"<|buffer{i}|>" for i in range(1, num_buffer + 1)] # 预留一定数量的token位置
all_special_tokens = special_tokens_list + additional_tokens_list + buffer_tokens
# Initialize the tokenizer and trainer, and set up pre-tokenization
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
trainer = trainers.BpeTrainer(
vocab_size=vocab_size,
show_progress=True,
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
special_tokens=all_special_tokens
)
texts = get_texts(data_path)
########################### Training operations ###########################
tokenizer.train_from_iterator(texts, trainer=trainer)
########################### Post-training operations ###########################
# Post-processing
tokenizer.decoder = decoders.ByteLevel()
tokenizer.add_special_tokens(special_tokens_list)
os.makedirs(tokenizer_dir, exist_ok=True)
tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
# Modify and save the tokenizer JSON file
tokenizer.save(tokenizer_json_path)
tokenizer.model.save(tokenizer_dir)
tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json")
with open(tokenizer_json_path, 'r', encoding='utf-8') as f:
tokenizer_data = json.load(f)
for token_info in tokenizer_data.get('added_tokens', []):
@ -63,6 +80,7 @@ def train_tokenizer(data_path, tokenizer_dir, vocab_size, special_tokens_num=SPE
with open(tokenizer_json_path, 'w', encoding='utf-8') as f:
json.dump(tokenizer_data, f, ensure_ascii=False, indent=2)
# Create and save the tokenizer config file
added_tokens_decoder = {}
for i, token in enumerate(all_special_tokens):
idx = tokenizer.token_to_id(token)
@ -101,13 +119,19 @@ def train_tokenizer(data_path, tokenizer_dir, vocab_size, special_tokens_num=SPE
"tokenizer_class": "PreTrainedTokenizerFast"
}
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as f:
with open(tokenizer_config_path, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=4)
# Print the training-completion message
print("Tokenizer training completed.")
def eval_tokenizer(tokenizer_dir):
from transformers import AutoTokenizer
########################### Pre-evaluation operations ###########################
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
# Create test messages
messages = [
{"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"},
{"role": "user", "content": '你来自哪里?'},
@ -119,14 +143,21 @@ def eval_tokenizer(tokenizer_dir):
messages,
tokenize=False
)
########################### Evaluation operations ###########################
# Chat template test
print('-'*100)
print(new_prompt)
# Basic info test
print('-'*100)
print('tokenizer vocab size', len(tokenizer))
model_inputs = tokenizer(new_prompt)
print('encoder length', len(model_inputs['input_ids']))
response = tokenizer.decode(model_inputs['input_ids'], skip_special_tokens=False)
print('decoder consistency', response == new_prompt, "\n")
# Compression ratio test
print('-'*100)
print('Compression ratio test (Chars/Tokens)')
test_texts = [
@ -150,6 +181,8 @@ def eval_tokenizer(tokenizer_dir):
print(f"样本 {i+1} | 字符数: {char_count:4} | Tokens: {token_count:3} | 压缩率: {compression_ratio:.2f}")
print(f"平均压缩率: {total_compression / len(test_texts):.2f}")
# Streaming decode test
print('-'*100)
print('Streaming decode (byte buffering) test:')
input_ids = model_inputs['input_ids']
@ -162,6 +195,10 @@ def eval_tokenizer(tokenizer_dir):
raw_tokens = [tokenizer.convert_ids_to_tokens(int(t)) for t in (token_cache if isinstance(token_cache, list) else [token_cache])]
print(f'Token ID: {str(display_ids):15} -> Raw: {str(raw_tokens):20} -> Decode Str: {current_decode}')
token_cache = []
########################### Post-evaluation operations ###########################
# Print the evaluation-completion message
print("Tokenizer evaluation completed.")
if __name__ == '__main__':
train_tokenizer(DATA_PATH, TOKENIZER_DIR, VOCAB_SIZE)
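
For context, one common way to implement the byte-buffered streaming decode tested above (a sketch; the exact completeness check used in eval_tokenizer may differ):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
ids = tokenizer('你好,世界')['input_ids']
token_cache = []
for t in ids:
    # Byte-level BPE can split one UTF-8 character across several tokens, so
    # buffer ids until they decode without a replacement character (U+FFFD).
    token_cache.append(t)
    piece = tokenizer.decode(token_cache)
    if '\ufffd' not in piece:
        print(piece, end='')
        token_cache = []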