mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00
一些小的改动
1、给模型定义和模型训练的相关程序添加注释,并移动代码位置,增加代码的可读性和可理解性 2、修改了训练因故中断后,无法正常从分布式环境中续训的问题
This commit is contained in:
parent
88e675dc2c
commit
ef40a1f271
@ -11,21 +11,17 @@ class MiniMindConfig(PretrainedConfig):
|
||||
model_type = "minimind"
|
||||
def __init__(self, hidden_size=768, num_hidden_layers=8, use_moe=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.use_moe = use_moe
|
||||
self.dropout = kwargs.get("dropout", 0.0)
|
||||
####################################################
|
||||
# token相关
|
||||
####################################################
|
||||
self.vocab_size = kwargs.get("vocab_size", 6400)
|
||||
self.bos_token_id = kwargs.get("bos_token_id", 1)
|
||||
self.eos_token_id = kwargs.get("eos_token_id", 2)
|
||||
self.flash_attn = kwargs.get("flash_attn", True)
|
||||
self.num_attention_heads = kwargs.get("num_attention_heads", 8)
|
||||
self.num_key_value_heads = kwargs.get("num_key_value_heads", 4)
|
||||
self.head_dim = kwargs.get("head_dim", self.hidden_size // self.num_attention_heads)
|
||||
self.hidden_act = kwargs.get("hidden_act", 'silu')
|
||||
self.intermediate_size = kwargs.get("intermediate_size", math.ceil(hidden_size * math.pi / 64) * 64)
|
||||
|
||||
####################################################
|
||||
# embedding相关
|
||||
####################################################
|
||||
self.max_position_embeddings = kwargs.get("max_position_embeddings", 32768)
|
||||
self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6)
|
||||
self.rope_theta = kwargs.get("rope_theta", 1e6)
|
||||
self.inference_rope_scaling = kwargs.get("inference_rope_scaling", False)
|
||||
self.rope_scaling = {
|
||||
@ -36,6 +32,35 @@ class MiniMindConfig(PretrainedConfig):
|
||||
"attention_factor": 1.0,
|
||||
"type": "yarn"
|
||||
} if self.inference_rope_scaling else None
|
||||
|
||||
####################################################
|
||||
# 表示空间(Representation Space)相关
|
||||
####################################################
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
####################################################
|
||||
# transformer相关
|
||||
####################################################
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = kwargs.get("num_attention_heads", 8)
|
||||
# GQA中的KV复用机制
|
||||
self.num_key_value_heads = kwargs.get("num_key_value_heads", 4)
|
||||
self.head_dim = kwargs.get("head_dim", self.hidden_size // self.num_attention_heads)
|
||||
self.flash_attn = kwargs.get("flash_attn", True)
|
||||
|
||||
####################################################
|
||||
# 前馈网络相关
|
||||
####################################################
|
||||
self.intermediate_size = kwargs.get("intermediate_size", math.ceil(hidden_size * math.pi / 64) * 64)
|
||||
self.hidden_act = kwargs.get("hidden_act", 'silu')
|
||||
|
||||
####################################################
|
||||
# 模型整体架构相关
|
||||
####################################################
|
||||
self.use_moe = use_moe
|
||||
self.dropout = kwargs.get("dropout", 0.0)
|
||||
self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6)
|
||||
|
||||
### MoE specific configs (ignored if use_moe = False)
|
||||
self.num_experts = kwargs.get("num_experts", 4)
|
||||
self.num_experts_per_tok = kwargs.get("num_experts_per_tok", 1)
|
||||
@ -46,17 +71,6 @@ class MiniMindConfig(PretrainedConfig):
|
||||
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
|
||||
# MiniMind Model
|
||||
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
return (self.weight * self.norm(x.float())).type_as(x)
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: dict = None):
|
||||
freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
|
||||
@ -87,6 +101,18 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
if n_rep == 1: return x
|
||||
return (x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim))
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
return (self.weight * self.norm(x.float())).type_as(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config: MiniMindConfig):
|
||||
super().__init__()
|
||||
|
||||
@ -126,10 +126,12 @@ def convert_json_to_jinja(json_file_path, output_path):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
lm_config = MiniMindConfig(hidden_size=768, num_hidden_layers=8, max_seq_len=8192, use_moe=True)
|
||||
|
||||
# 注意这里use_moe参数的配置,默认使用非MoE模型
|
||||
lm_config = MiniMindConfig(hidden_size=768, num_hidden_layers=8, max_seq_len=8192, use_moe=False)
|
||||
# convert torch to transformers
|
||||
torch_path = f"../out/full_sft_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
|
||||
transformers_path = '../minimind-3-moe'
|
||||
transformers_path = '../minimind-3'
|
||||
convert_torch2transformers(torch_path, transformers_path)
|
||||
|
||||
# # merge lora
|
||||
|
||||
@ -241,15 +241,19 @@ def calculate_rewards(prompts, completions, gt_batch, tools_batch, num_gen, rewa
|
||||
def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model=None, start_step=0, wandb=None, use_sglang=False):
|
||||
last_step = start_step
|
||||
for step, batch in enumerate(loader, start=start_step + 1):
|
||||
########################### 训练前操作 ###########################
|
||||
# 数据准备
|
||||
messages_batch = batch['messages']
|
||||
tools_batch = batch['tools']
|
||||
gt_batch = batch['gt']
|
||||
last_step = step
|
||||
|
||||
prompts = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True, tools=t) for m, t in zip(messages_batch, tools_batch)]
|
||||
|
||||
with torch.no_grad():
|
||||
completions, contexts, prompt_ids_batch, response_ids_batch, response_masks_batch, response_old_logps_batch, turn_outputs_batch, unfinished_batch = rollout_batch(rollout_engine, tokenizer, messages_batch, tools_batch, args.num_generations, max_turns=3, max_new_tokens=args.max_gen_len, thinking_ratio=args.thinking_ratio, device=args.device)
|
||||
|
||||
prompts = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True, tools=t) for m, t in zip(messages_batch, tools_batch)]
|
||||
# 数据处理(保证序列长度一致)
|
||||
packed_samples = []
|
||||
for p, r, m, old_lp in zip(prompt_ids_batch, response_ids_batch, response_masks_batch, response_old_logps_batch):
|
||||
ids = p + r
|
||||
@ -268,6 +272,8 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
|
||||
full_response_masks = torch.tensor([mask + [0] * (max_len - len(mask)) for _, mask, _, _ in packed_samples], device=args.device, dtype=torch.float32)
|
||||
old_per_token_logps = torch.tensor([old_logps + [0.0] * ((max_len - 1) - len(old_logps)) for _, _, _, old_logps in packed_samples], device=args.device, dtype=torch.float32)
|
||||
|
||||
########################### 训练中操作 ###########################
|
||||
# 数据计算
|
||||
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
with autocast_ctx:
|
||||
res = model_unwrapped(input_ids)
|
||||
@ -316,6 +322,8 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
|
||||
kl_div = ref_per_token_logps - per_token_logps
|
||||
per_token_kl = torch.exp(kl_div) - kl_div - 1
|
||||
ratio = torch.exp(per_token_logps - old_per_token_logps)
|
||||
|
||||
# 定义损失函数
|
||||
if args.loss_type == "cispo":
|
||||
clamped_ratio = torch.clamp(ratio, max=args.epsilon_high).detach()
|
||||
per_token_loss = -(clamped_ratio * advantages.unsqueeze(1) * per_token_logps - args.beta * per_token_kl)
|
||||
@ -327,13 +335,18 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
|
||||
policy_loss = (((per_token_loss * completion_mask).sum(dim=1)[valid_rows] / token_counts[valid_rows].clamp(min=1)).mean()
|
||||
if valid_rows.any() else per_token_loss.sum() * 0.0)
|
||||
loss = (policy_loss + aux_loss) / args.accumulation_steps
|
||||
|
||||
# 反向传播
|
||||
loss.backward()
|
||||
|
||||
# 梯度更新
|
||||
if step % args.accumulation_steps == 0:
|
||||
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
optimizer.step(); scheduler.step(); optimizer.zero_grad()
|
||||
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model)
|
||||
|
||||
########################### 训练后操作 ###########################
|
||||
# 日志打印
|
||||
if step % args.log_interval == 0 or step == iters:
|
||||
pl = loss.item() * args.accumulation_steps
|
||||
ar = rewards.mean().item()
|
||||
@ -346,6 +359,7 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
|
||||
if wandb and is_main_process():
|
||||
wandb.log({"reward":ar,"kl_ref":kl,"group_reward_std":gs,"advantages_std":ast,"policy_loss":pl,"avg_response_len":al,"advantages_mean":am,"learning_rate":lr})
|
||||
|
||||
# 模型保存
|
||||
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||
model.eval()
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
@ -409,25 +423,20 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_agent", help="SGLang共享存储路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 创建训练环境 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 模型相关 ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
|
||||
|
||||
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb.init(project=args.wandb_project, name=f"Agent-RL-E{args.epochs}-B{args.batch_size}-LR{args.learning_rate}", id=wandb_id, resume=resume)
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
|
||||
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
|
||||
@ -456,14 +465,6 @@ if __name__ == "__main__":
|
||||
total_optimizer_steps = math.ceil(iters / args.accumulation_steps) * args.epochs
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||||
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scheduler.load_state_dict(ckp_data['scheduler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
@ -472,6 +473,25 @@ if __name__ == "__main__":
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
if is_main_process(): rollout_engine.update_policy(model)
|
||||
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
|
||||
|
||||
# ========== 3. checkpoint相关 ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb.init(project=args.wandb_project, name=f"Agent-RL-E{args.epochs}-B{args.batch_size}-LR{args.learning_rate}", id=wandb_id, resume=resume)
|
||||
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.module.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scheduler.load_state_dict(ckp_data['scheduler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 4. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
@ -484,4 +504,5 @@ if __name__ == "__main__":
|
||||
else:
|
||||
rl_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
|
||||
|
||||
# ========== 5. 清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
|
||||
@ -215,7 +215,7 @@ if __name__ == "__main__":
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
model.module.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
|
||||
@ -54,6 +54,8 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
|
||||
last_step = start_step
|
||||
|
||||
for step, batch in enumerate(loader, start=start_step + 1):
|
||||
########################### 训练前操作 ###########################
|
||||
# 数据加载
|
||||
last_step = step
|
||||
x_chosen = batch['x_chosen'].to(args.device)
|
||||
x_rejected = batch['x_rejected'].to(args.device)
|
||||
@ -64,11 +66,14 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
|
||||
x = torch.cat([x_chosen, x_rejected], dim=0)
|
||||
y = torch.cat([y_chosen, y_rejected], dim=0)
|
||||
mask = torch.cat([mask_chosen, mask_rejected], dim=0)
|
||||
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
|
||||
# 学习率调整
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
########################### 训练中操作 ###########################
|
||||
# 模型前向传播
|
||||
with autocast_ctx:
|
||||
with torch.no_grad():
|
||||
ref_outputs = ref_model(x)
|
||||
@ -83,8 +88,10 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
|
||||
loss = dpo_loss_val + outputs.aux_loss
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
# 模型反向传播
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# 梯度更新
|
||||
if step % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
@ -92,6 +99,8 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
########################### 训练后操作 ###########################
|
||||
# 日志打印
|
||||
if step % args.log_interval == 0 or step == iters:
|
||||
spend_time = time.time() - start_time
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
@ -104,6 +113,7 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
|
||||
|
||||
if wandb: wandb.log({"loss": current_loss, "dpo_loss": current_dpo_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
|
||||
|
||||
# 模型保存
|
||||
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||
model.eval()
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
@ -154,31 +164,20 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
# ========== 1. 创建训练环境 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
# ========== 2. 模型相关 ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型和参考模型 ==========
|
||||
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
Logger(f'策略模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
|
||||
# 初始化参考模型(ref_model冻结)
|
||||
@ -191,25 +190,34 @@ if __name__ == "__main__":
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. 编译和分布式包装 ==========
|
||||
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
# ========== 3.checkpoint相关 ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.module.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 4. 训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
@ -222,5 +230,5 @@ if __name__ == "__main__":
|
||||
else:
|
||||
train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
# ========== 5. 清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -19,26 +19,32 @@ from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
start_time = time.time()
|
||||
last_step = start_step
|
||||
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
|
||||
|
||||
########################### 训练前操作 ###########################
|
||||
# 数据加载
|
||||
input_ids = input_ids.to(args.device)
|
||||
labels = labels.to(args.device)
|
||||
last_step = step
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
|
||||
# 学习率调整
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
########################### 训练中操作 ###########################
|
||||
# 模型前向传播
|
||||
with autocast_ctx:
|
||||
res = model(input_ids, labels=labels)
|
||||
loss = res.loss + res.aux_loss
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
# 模型反向传播
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if step % args.accumulation_steps == 0:
|
||||
# 梯度更新
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
|
||||
@ -47,17 +53,20 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if step % args.log_interval == 0 or step == iters:
|
||||
########################### 训练后操作 ###########################
|
||||
# 日志打印
|
||||
if step % args.log_interval == 0 or step == iters - 1:
|
||||
spend_time = time.time() - start_time
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
|
||||
current_logits_loss = current_loss - current_aux_loss
|
||||
current_lr = optimizer.param_groups[-1]['lr']
|
||||
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
|
||||
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
|
||||
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
|
||||
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
|
||||
|
||||
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||
# 模型保存
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
model.eval()
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
@ -65,20 +74,12 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
raw_model = getattr(raw_model, '_orig_mod', raw_model)
|
||||
state_dict = raw_model.state_dict()
|
||||
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
|
||||
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||
model.train()
|
||||
del state_dict
|
||||
|
||||
del input_ids, labels, res, loss
|
||||
|
||||
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind Full SFT")
|
||||
@ -86,7 +87,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--save_weight', default='full_sft', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
|
||||
parser.add_argument("--batch_size", type=int, default=16, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-5, help="初始学习率")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-6, help="初始学习率")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
|
||||
@ -96,7 +97,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--max_seq_len', default=768, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
||||
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/sft_t2t_mini.jsonl", help="训练数据路径")
|
||||
parser.add_argument('--from_weight', default='pretrain', type=str, help="基于哪个权重训练,为none则不基于任何权重训练")
|
||||
@ -106,55 +107,52 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
# ========== 1. 创建训练环境 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}" # 分布式训练时,设置设备
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) # 设置随机种子
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
# ========== 2. 模型相关 ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True) # 创建模型保存目录
|
||||
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) # 设置混合精度
|
||||
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) # 所有模型相关的初始化
|
||||
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank]) # 分布式训练模型初始化
|
||||
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None # 检查checkpoint
|
||||
|
||||
# ========== 3. checkpoint相关 ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型、数据、优化器 ==========
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
wandb_run_name = f"MiniMind-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) # 通过checkpoint进行训练可视化
|
||||
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
model.module.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
start_step = ckp_data.get('step', 0) # 通过checkpoint进行状态恢复
|
||||
|
||||
# ========== 7. 编译和分布式包装 ==========
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
# ========== 4. 训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
@ -167,5 +165,5 @@ if __name__ == "__main__":
|
||||
else:
|
||||
train_epoch(epoch, loader, len(loader), 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
# ========== 5. 撤销训练环境 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group() # 撤销分布式训练环境
|
||||
|
||||
@ -299,7 +299,7 @@ if __name__ == "__main__":
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
model.module.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scheduler.load_state_dict(ckp_data['scheduler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
|
||||
@ -25,20 +25,28 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
|
||||
start_time = time.time()
|
||||
last_step = start_step
|
||||
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
|
||||
########################### 训练前操作 ###########################
|
||||
# 数据加载
|
||||
input_ids = input_ids.to(args.device)
|
||||
labels = labels.to(args.device)
|
||||
last_step = step
|
||||
|
||||
# 学习率调整
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
########################### 训练中操作 ###########################
|
||||
# 模型前向传播
|
||||
with autocast_ctx:
|
||||
res = model(input_ids, labels=labels)
|
||||
loss = res.loss + res.aux_loss
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
# 模型反向传播
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
|
||||
# 梯度更新
|
||||
if step % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
|
||||
@ -46,6 +54,8 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
########################### 训练后操作 ###########################
|
||||
# 日志打印
|
||||
if step % args.log_interval == 0 or step == iters:
|
||||
spend_time = time.time() - start_time
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
@ -56,10 +66,10 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
|
||||
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
|
||||
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
|
||||
|
||||
# 模型保存(仅保存LoRA权重)
|
||||
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||
model.eval()
|
||||
lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}.pth'
|
||||
# LoRA只保存LoRA权重
|
||||
save_lora(model, lora_save_path)
|
||||
lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||
model.train()
|
||||
@ -99,41 +109,27 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
# ========== 1. 创建训练环境 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
# ========== 2. 模型相关 ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型、应用LoRA、冻结非LoRA参数 ==========
|
||||
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
apply_lora(model)
|
||||
|
||||
# 统计参数
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)
|
||||
Logger(f"LLM 总参数量: {total_params / 1e6:.3f} M")
|
||||
Logger(f"LoRA 参数量: {lora_params_count / 1e6:.3f} M")
|
||||
Logger(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")
|
||||
|
||||
# 冻结非LoRA参数,收集LoRA参数
|
||||
lora_params = []
|
||||
for name, param in model.named_parameters():
|
||||
@ -143,13 +139,29 @@ if __name__ == "__main__":
|
||||
else:
|
||||
param.requires_grad = False
|
||||
|
||||
# ========== 6. 定义数据和优化器 ==========
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
|
||||
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. checkpoint相关 ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 7. 从ckp恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'], strict=False)
|
||||
@ -158,15 +170,7 @@ if __name__ == "__main__":
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 8. 编译和分布式包装 ==========
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 9. 开始训练 ==========
|
||||
# ========== 4. 训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
@ -179,5 +183,5 @@ if __name__ == "__main__":
|
||||
else:
|
||||
train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)
|
||||
|
||||
# ========== 10. 清理分布进程 ==========
|
||||
# ========== 5. 撤销训练环境 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -25,28 +25,24 @@ from trainer.rollout_engine import create_rollout_engine
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def rep_penalty(text, n=3, cap=0.5):
|
||||
toks = re.findall(r"\w+|[^\w\s]", text.lower())
|
||||
grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)]
|
||||
return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0
|
||||
|
||||
|
||||
# 自定义的Critic模型,继承自MiniMindLM
|
||||
class CriticModel(MiniMindForCausalLM):
|
||||
def __init__(self, params):
|
||||
super().__init__(params)
|
||||
# 替换lm_head为输出单一价值的线性层
|
||||
self.value_head = nn.Linear(params.hidden_size, 1)
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, **kwargs):
|
||||
# 使用基础模型获取隐藏状态
|
||||
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
|
||||
hidden_states = self.model.norm(outputs[0])
|
||||
# 使用value_head获取价值估计
|
||||
# 使用value_head替代lm_head,获取价值估计
|
||||
values = self.value_head(hidden_states).squeeze(-1)
|
||||
return values
|
||||
|
||||
def rep_penalty(text, n=3, cap=0.5):
|
||||
toks = re.findall(r"\w+|[^\w\s]", text.lower())
|
||||
grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)]
|
||||
return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0
|
||||
|
||||
def calculate_rewards(prompts, responses, reward_model):
|
||||
rewards = torch.zeros(len(responses), device=args.device)
|
||||
@ -81,6 +77,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
grad_accum_step = 0
|
||||
|
||||
for step, batch in enumerate(loader, start=start_step + 1):
|
||||
########################### 训练前操作 ###########################
|
||||
# 数据准备
|
||||
prompts = batch["prompt"] # list[str], length B
|
||||
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_seq_len,
|
||||
padding_side="left").to(args.device) # input_ids: [B, P], attention_mask: [B, P]
|
||||
@ -110,7 +108,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
Logger(f"{'=' * 29} [DEBUG] sample[{i}] RESPONSE_END {'=' * 29}")
|
||||
Logger(f"[DEBUG] reward={rewards[i].item():.4f}")
|
||||
Logger('='*100)
|
||||
|
||||
|
||||
# 数据处理
|
||||
full_mask = (gen_out != tokenizer.pad_token_id).long() # [B, P+R]
|
||||
labels = gen_out[:, 1:].clone() # [B, P+R-1]
|
||||
seq_len, resp_start = gen_out.size(1) - 1, prompt_length - 1
|
||||
@ -126,6 +125,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
resp_policy_mask = ((resp_idx < resp_lengths.unsqueeze(1)) & resp_pad_mask).float()
|
||||
resp_value_mask = resp_policy_mask.clone()
|
||||
|
||||
########################### 训练中操作 ###########################
|
||||
# 数据计算,初始化优势函数advantages和价值函数returns
|
||||
with torch.no_grad(): # Rollout阶段只需推理获取old_logp和old_values,切断梯度省显存
|
||||
critic_for_rollout = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model
|
||||
values_seq = critic_for_rollout(input_ids=gen_out, attention_mask=full_mask)
|
||||
@ -155,7 +156,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
adv_mean = (advantages * resp_policy_mask).sum() / resp_policy_mask.sum().clamp(min=1)
|
||||
adv_var = ((advantages - adv_mean) ** 2 * resp_policy_mask).sum() / resp_policy_mask.sum().clamp(min=1)
|
||||
advantages = (advantages - adv_mean) * torch.rsqrt(adv_var + 1e-8) * resp_policy_mask
|
||||
|
||||
|
||||
# 训练参数初始化
|
||||
mb_size = max(1, min(args.mini_batch_size, B))
|
||||
stop_ppo = False
|
||||
policy_loss_sum = 0.0
|
||||
@ -167,6 +169,8 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
log_count = 0
|
||||
actor_unwrapped = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
|
||||
critic_unwrapped = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model
|
||||
|
||||
# 强化学习训练迭代
|
||||
for ppo_epoch in range(args.ppo_update_iters):
|
||||
if stop_ppo:
|
||||
break
|
||||
@ -240,6 +244,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
actor_optimizer.zero_grad()
|
||||
critic_optimizer.zero_grad()
|
||||
|
||||
# 梯度更新
|
||||
if grad_accum_step % args.accumulation_steps != 0:
|
||||
clip_grad_norm_(actor_model.parameters(), args.grad_clip)
|
||||
clip_grad_norm_(critic_model.parameters(), args.grad_clip)
|
||||
@ -249,9 +254,12 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
critic_scheduler.step()
|
||||
actor_optimizer.zero_grad()
|
||||
critic_optimizer.zero_grad()
|
||||
|
||||
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(actor_model)
|
||||
|
||||
########################### 训练后操作 ###########################
|
||||
# 模型更新
|
||||
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(actor_model)
|
||||
|
||||
# 日志打印
|
||||
if is_main_process():
|
||||
critic_loss_val = value_loss_sum / max(log_count, 1)
|
||||
reward_val = rewards.mean().item()
|
||||
@ -278,6 +286,7 @@ def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_sched
|
||||
f"ClipFrac: {clipfrac_val:.4f}, Critic Loss: {critic_loss_val:.4f}, "
|
||||
f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.8f}, Critic LR: {critic_lr:.8f}")
|
||||
|
||||
# 模型保存
|
||||
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||
actor_model.eval()
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
@ -345,43 +354,37 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_ppo", help="SGLang共享存储路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
# ========== 1. 创建训练环境 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
# ========== 2. 模型相关 ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 初始化模型和数据 ==========
|
||||
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
|
||||
base_weight = args.from_weight
|
||||
# Actor模型
|
||||
|
||||
# LLM_PPO四大模型初始化
|
||||
actor_model, tokenizer = init_model(lm_config, base_weight, device=args.device)
|
||||
|
||||
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
|
||||
ref_model = ref_model.eval().requires_grad_(False)
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
|
||||
critic_model = CriticModel(lm_config)
|
||||
critic_model.load_state_dict(state_dict, strict=False)
|
||||
critic_model = critic_model.to(args.device)
|
||||
|
||||
reward_model = LMForRewardModel(args.reward_model_path, device=args.device, dtype=torch.float16)
|
||||
|
||||
# Rollout引擎
|
||||
rollout_engine = create_rollout_engine(
|
||||
engine_type=args.rollout_engine,
|
||||
@ -393,6 +396,7 @@ if __name__ == "__main__":
|
||||
sglang_model_path=args.sglang_model_path,
|
||||
sglang_shared_path=args.sglang_shared_path,
|
||||
)
|
||||
|
||||
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len), thinking_ratio=args.thinking_ratio)
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
actor_optimizer = optim.AdamW(actor_model.parameters(), lr=args.learning_rate)
|
||||
@ -404,18 +408,6 @@ if __name__ == "__main__":
|
||||
actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||||
critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, eta_min=args.critic_learning_rate / 10)
|
||||
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
actor_model.load_state_dict(ckp_data['model'])
|
||||
critic_model.load_state_dict(ckp_data['critic_model'])
|
||||
actor_optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
critic_optimizer.load_state_dict(ckp_data['critic_optimizer'])
|
||||
actor_scheduler.load_state_dict(ckp_data['scheduler'])
|
||||
critic_scheduler.load_state_dict(ckp_data['critic_scheduler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. 编译和分布式包装 ==========
|
||||
if args.use_compile == 1:
|
||||
actor_model = torch.compile(actor_model)
|
||||
Logger('torch.compile enabled')
|
||||
@ -426,8 +418,30 @@ if __name__ == "__main__":
|
||||
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
|
||||
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
|
||||
if is_main_process(): rollout_engine.update_policy(actor_model)
|
||||
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. checkpoint相关 ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
actor_model.module.load_state_dict(ckp_data['model'])
|
||||
critic_model.module.load_state_dict(ckp_data['critic_model'])
|
||||
actor_optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
critic_optimizer.load_state_dict(ckp_data['critic_optimizer'])
|
||||
actor_scheduler.load_state_dict(ckp_data['scheduler'])
|
||||
critic_scheduler.load_state_dict(ckp_data['critic_scheduler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
# ========== 4. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
@ -440,5 +454,5 @@ if __name__ == "__main__":
|
||||
else:
|
||||
ppo_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, actor_scheduler, critic_scheduler, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
# ========== 5. 清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -22,23 +22,30 @@ warnings.filterwarnings('ignore')
|
||||
|
||||
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
start_time = time.time()
|
||||
last_step = start_step
|
||||
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
|
||||
|
||||
########################### 训练前操作 ###########################
|
||||
# 数据加载
|
||||
input_ids = input_ids.to(args.device)
|
||||
labels = labels.to(args.device)
|
||||
last_step = step
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
|
||||
# 学习率调整
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
########################### 训练中操作 ###########################
|
||||
# 模型前向传播
|
||||
with autocast_ctx:
|
||||
res = model(input_ids, labels=labels)
|
||||
loss = res.loss + res.aux_loss
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
# 模型反向传播
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if step % args.accumulation_steps == 0:
|
||||
# 梯度更新
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
|
||||
@ -47,17 +54,20 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if step % args.log_interval == 0 or step == iters:
|
||||
########################### 训练后操作 ###########################
|
||||
# 日志打印
|
||||
if step % args.log_interval == 0 or step == iters - 1:
|
||||
spend_time = time.time() - start_time
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
|
||||
current_logits_loss = current_loss - current_aux_loss
|
||||
current_lr = optimizer.param_groups[-1]['lr']
|
||||
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
|
||||
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
|
||||
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
|
||||
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
|
||||
|
||||
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||
# 模型保存
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
model.eval()
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
@ -71,19 +81,12 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
|
||||
del input_ids, labels, res, loss
|
||||
|
||||
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
|
||||
parser.add_argument("--epochs", type=int, default=1, help="训练轮数(建议1轮zero或2-6轮充分训练)")
|
||||
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
@ -105,55 +108,52 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
# ========== 1. 创建训练环境 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}" # 分布式训练时,设置设备
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) # 设置随机种子
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
# ========== 2. 模型相关 ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True) # 创建模型保存目录
|
||||
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) # 设置混合精度
|
||||
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) # 所有模型相关的初始化
|
||||
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank]) # 分布式训练模型初始化
|
||||
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None # 检查checkpoint
|
||||
|
||||
# ========== 3. checkpoint相关 ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型、数据、优化器 ==========
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) # 通过checkpoint进行训练可视化
|
||||
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
model.module.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
start_step = ckp_data.get('step', 0) # 通过checkpoint进行状态恢复
|
||||
|
||||
# ========== 7. 编译和分布式包装 ==========
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
# ========== 4. 训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
@ -166,5 +166,5 @@ if __name__ == "__main__":
|
||||
else:
|
||||
train_epoch(epoch, loader, len(loader), 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
# ========== 5. 撤销训练环境 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group() # 撤销分布式训练环境
|
||||
|
||||
@ -9,22 +9,32 @@ TOKENIZER_DIR = '../model_learn_tokenizer/'
|
||||
VOCAB_SIZE = 6400
|
||||
SPECIAL_TOKENS_NUM = 36
|
||||
|
||||
# 获取文本(train_tokenizer辅助函数)
|
||||
def get_texts(data_path):
|
||||
with open(data_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
for i, line in enumerate(f):
|
||||
if i >= 10000: break # 选10000行测试
|
||||
if i >= 10000: break # 选10000行测试(注释掉该行,可以进行所有数据集的文本读取)
|
||||
try:
|
||||
data = json.loads(line)
|
||||
# 仅利用SFT数据集中conversations字段的content字段,忽略reasoning_content、tools、tool_calls字段
|
||||
contents = [item.get('content') for item in data.get('conversations', []) if item.get('content')]
|
||||
if contents:
|
||||
# 注意这里yeild的使用
|
||||
yield "\n".join(contents)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
def train_tokenizer(data_path, tokenizer_dir, vocab_size, special_tokens_num=SPECIAL_TOKENS_NUM):
|
||||
tokenizer = Tokenizer(models.BPE())
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
||||
###########################训练前操作###########################
|
||||
# 获取文本
|
||||
texts = get_texts(data_path)
|
||||
|
||||
# 创建tokenizer目录
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json")
|
||||
tokenizer_config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
|
||||
|
||||
# special_tokens定义
|
||||
special_tokens_list = [
|
||||
"<|endoftext|>", "<|im_start|>", "<|im_end|>",
|
||||
"<|object_ref_start|>", "<|object_ref_end|>", "<|box_start|>", "<|box_end|>", "<|quad_start|>", "<|quad_end|>",
|
||||
@ -40,21 +50,28 @@ def train_tokenizer(data_path, tokenizer_dir, vocab_size, special_tokens_num=SPE
|
||||
num_buffer = special_tokens_num - len(special_tokens_list + additional_tokens_list)
|
||||
buffer_tokens = [f"<|buffer{i}|>" for i in range(1, num_buffer + 1)] # 预留一定数量的token位置
|
||||
all_special_tokens = special_tokens_list + additional_tokens_list + buffer_tokens
|
||||
|
||||
# 分词器、训练器初始化及预处理
|
||||
tokenizer = Tokenizer(models.BPE())
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
||||
trainer = trainers.BpeTrainer(
|
||||
vocab_size=vocab_size,
|
||||
show_progress=True,
|
||||
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
|
||||
special_tokens=all_special_tokens
|
||||
)
|
||||
texts = get_texts(data_path)
|
||||
|
||||
###########################训练中操作###########################
|
||||
tokenizer.train_from_iterator(texts, trainer=trainer)
|
||||
|
||||
###########################训练后操作###########################
|
||||
# 后处理
|
||||
tokenizer.decoder = decoders.ByteLevel()
|
||||
tokenizer.add_special_tokens(special_tokens_list)
|
||||
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
# 修改并保存分词器json文件
|
||||
tokenizer.save(tokenizer_json_path)
|
||||
tokenizer.model.save(tokenizer_dir)
|
||||
tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json")
|
||||
with open(tokenizer_json_path, 'r', encoding='utf-8') as f:
|
||||
tokenizer_data = json.load(f)
|
||||
for token_info in tokenizer_data.get('added_tokens', []):
|
||||
@ -63,6 +80,7 @@ def train_tokenizer(data_path, tokenizer_dir, vocab_size, special_tokens_num=SPE
|
||||
with open(tokenizer_json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(tokenizer_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 创建并保存分词器config文件
|
||||
added_tokens_decoder = {}
|
||||
for i, token in enumerate(all_special_tokens):
|
||||
idx = tokenizer.token_to_id(token)
|
||||
@ -101,13 +119,19 @@ def train_tokenizer(data_path, tokenizer_dir, vocab_size, special_tokens_num=SPE
|
||||
"tokenizer_class": "PreTrainedTokenizerFast"
|
||||
}
|
||||
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as f:
|
||||
with open(tokenizer_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=4)
|
||||
|
||||
# 打印训练完成信息
|
||||
print("Tokenizer training completed.")
|
||||
|
||||
def eval_tokenizer(tokenizer_dir):
|
||||
from transformers import AutoTokenizer
|
||||
###########################评估前操作###########################
|
||||
# 加载tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
|
||||
|
||||
# 创建测试消息
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"},
|
||||
{"role": "user", "content": '你来自哪里?'},
|
||||
@ -119,14 +143,21 @@ def eval_tokenizer(tokenizer_dir):
|
||||
messages,
|
||||
tokenize=False
|
||||
)
|
||||
|
||||
###########################评估中操作###########################
|
||||
# 聊天模版测试
|
||||
print('-'*100)
|
||||
print(new_prompt)
|
||||
|
||||
# 基础信息测试
|
||||
print('-'*100)
|
||||
print('tokenizer词表长度:', len(tokenizer))
|
||||
model_inputs = tokenizer(new_prompt)
|
||||
print('encoder长度:', len(model_inputs['input_ids']))
|
||||
response = tokenizer.decode(model_inputs['input_ids'], skip_special_tokens=False)
|
||||
print('decoder一致性:', response == new_prompt, "\n")
|
||||
|
||||
# 压缩率测试
|
||||
print('-'*100)
|
||||
print('压缩率测试(Chars/Tokens):')
|
||||
test_texts = [
|
||||
@ -150,6 +181,8 @@ def eval_tokenizer(tokenizer_dir):
|
||||
print(f"样本 {i+1} | 字符数: {char_count:4} | Tokens: {token_count:3} | 压缩率: {compression_ratio:.2f}")
|
||||
|
||||
print(f"平均压缩率: {total_compression / len(test_texts):.2f}")
|
||||
|
||||
# 流式解码测试
|
||||
print('-'*100)
|
||||
print('流式解码(字节缓冲)测试:')
|
||||
input_ids = model_inputs['input_ids']
|
||||
@ -162,6 +195,10 @@ def eval_tokenizer(tokenizer_dir):
|
||||
raw_tokens = [tokenizer.convert_ids_to_tokens(int(t)) for t in (token_cache if isinstance(token_cache, list) else [token_cache])]
|
||||
print(f'Token ID: {str(display_ids):15} -> Raw: {str(raw_tokens):20} -> Decode Str: {current_decode}')
|
||||
token_cache = []
|
||||
|
||||
###########################评估后操作###########################
|
||||
# 打印评估完成信息
|
||||
print("Tokenizer evaluation completed.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
train_tokenizer(DATA_PATH, TOKENIZER_DIR, VOCAB_SIZE)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user