[fix]: Fixed the gradient explosion problem that occurred during pre-training of the MoE module.

dieu 2025-12-08 15:25:34 +08:00
parent cc29d9a351
commit c37d924b47
3 changed files with 96 additions and 23 deletions

View File

@@ -34,7 +34,7 @@ class MiniMindConfig(PretrainedConfig):
n_routed_experts: int = 4,
n_shared_experts: int = 1,
scoring_func: str = 'softmax',
aux_loss_alpha: float = 0.1,
aux_loss_alpha: float = 0.01,
seq_aux: bool = True,
norm_topk_prob: bool = True,
**kwargs
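The change above lowers aux_loss_alpha from 0.1 to 0.01, shrinking the load-balancing penalty relative to the language-modeling loss. As a rough illustration of why that scale matters, here is a generic Switch-Transformer-style balance loss in which alpha multiplies the penalty directly; this is only a sketch and not necessarily the exact seq_aux formula this repo uses:

```python
import torch

def balance_loss(gate_probs: torch.Tensor, topk_idx: torch.Tensor,
                 n_experts: int, alpha: float) -> torch.Tensor:
    """Generic Switch-style auxiliary loss: alpha * N * sum_i f_i * P_i (illustrative only)."""
    # f_i: fraction of tokens routed to expert i
    counts = torch.bincount(topk_idx.flatten(), minlength=n_experts).float()
    f = counts / topk_idx.numel()
    # P_i: mean routing probability assigned to expert i
    p = gate_probs.mean(dim=0)
    return alpha * n_experts * torch.sum(f * p)

probs = torch.softmax(torch.randn(1024, 4), dim=-1)   # routing probs for 1024 tokens, 4 experts
topk = probs.topk(2, dim=-1).indices                  # 2 experts selected per token
print(balance_loss(probs, topk, n_experts=4, alpha=0.01))  # ~10x smaller than with alpha=0.1
```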
@@ -249,7 +249,8 @@ class MoEGate(nn.Module):
self.reset_parameters()
def reset_parameters(self) -> None:
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
# init.kaiming_uniform_(self.weight, a=math.sqrt(5))
init.normal_(self.weight, mean=0.0, std=0.01)
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
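The gate re-initialization above is the core of the fix: a small-std normal init keeps the router logits near zero, so the softmax routing starts close to uniform and the gate gradients stay bounded, whereas the previous Kaiming-uniform init can produce large initial logits and very peaky routing. A minimal, self-contained sketch of the idea (TinyGate and its shapes are illustrative, not the repo's MoEGate):

```python
import math
import torch
import torch.nn as nn
import torch.nn.init as init

class TinyGate(nn.Module):
    """Illustrative router: maps hidden states to per-expert routing probabilities."""
    def __init__(self, hidden_size: int, n_routed_experts: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(n_routed_experts, hidden_size))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Old scheme (now commented out in the diff): Kaiming uniform.
        # init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # New scheme: small-std normal keeps the initial routing logits near zero.
        init.normal_(self.weight, mean=0.0, std=0.01)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = hidden_states @ self.weight.t()      # (bsz, seq, n_experts)
        return torch.softmax(logits, dim=-1)

gate = TinyGate(hidden_size=640, n_routed_experts=4)
probs = gate(torch.randn(2, 8, 640))
print(probs.mean(dim=(0, 1)))  # each expert close to 0.25 at init, i.e. near-uniform routing
```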
@@ -313,7 +314,7 @@ class MOEFeedForward(nn.Module):
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
y = torch.empty_like(x, dtype=x.dtype)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # keep the dtype consistent
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
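The other change in this file replaces the hard-coded float16 buffer with one that follows the input dtype. Under the bfloat16 autocast this trainer defaults to, routing expert outputs through a float16 buffer narrows the dynamic range and can overflow, feeding inf/NaN into the backward pass. A quick demonstration of the range difference (values chosen only to show the overflow point):

```python
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
y_old = torch.empty_like(x, dtype=torch.float16)  # old: hard-coded fp16 buffer
y_new = torch.empty_like(x, dtype=x.dtype)        # new: follows the compute dtype

print(y_old.dtype, y_new.dtype)                   # torch.float16 torch.bfloat16
# fp16 has a much smaller dynamic range than bf16, so large expert outputs
# written into a hard-coded fp16 buffer can overflow to inf under bf16 training.
print(torch.tensor(70000.0).to(torch.float16))    # inf  (fp16 max ≈ 65504)
print(torch.tensor(70000.0).to(torch.bfloat16))   # ~70144, still representable in bf16
```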

View File

@@ -27,19 +27,32 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
# lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate, args.warmup_steps, args.warmup_ratio)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with autocast_ctx:
res = model(X)
loss = loss_fct(
# loss = loss_fct(
# res.logits.view(-1, res.logits.size(-1)),
# Y.view(-1)
# ).view(Y.size())
#
# loss = (loss * loss_mask).sum() / loss_mask.sum()
# loss += res.aux_loss
# loss = loss / args.accumulation_steps
# --- [New monitoring 1] Split the loss for easier observation ---
raw_loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask.sum()
loss += res.aux_loss
main_loss = (raw_loss * loss_mask).sum() / loss_mask.sum()
aux_loss = res.aux_loss  # auxiliary (load-balancing) loss
# The ratio here matters: if aux_loss is even larger than main_loss, training will go off the rails
loss = main_loss + aux_loss
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
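The rewritten block keeps the objective itself unchanged (masked token-level cross-entropy plus the MoE auxiliary loss, scaled by the accumulation steps); it only exposes main_loss and aux_loss separately for the monitoring below. A standalone sketch of the same decomposition, assuming loss_fct is a token-level nn.CrossEntropyLoss(reduction='none') as the .view(Y.size()) call implies:

```python
import torch
import torch.nn as nn

def split_loss(logits, targets, loss_mask, aux_loss, accumulation_steps=1):
    """Sketch of the decomposed objective: masked LM loss + MoE balance loss."""
    loss_fct = nn.CrossEntropyLoss(reduction='none')               # assumed reduction
    raw = loss_fct(logits.view(-1, logits.size(-1)),
                   targets.view(-1)).view(targets.size())          # per-token loss
    main_loss = (raw * loss_mask).sum() / loss_mask.sum()          # masked mean
    total = (main_loss + aux_loss) / accumulation_steps
    return total, main_loss, aux_loss

# Toy shapes: batch=2, seq=4, vocab=10; aux_loss would come from the MoE gate.
logits = torch.randn(2, 4, 10)
targets = torch.randint(0, 10, (2, 4))
mask = torch.ones(2, 4)
total, main, aux = split_loss(logits, targets, mask, aux_loss=torch.tensor(0.02))
print(total.item(), main.item(), aux.item())
```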
@@ -59,10 +72,42 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
current_loss = loss.item() * args.accumulation_steps
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
# --- [New monitoring 2] Print the gate gradients and the aux loss ---
# 1. Unwrap the DDP wrapper to get the underlying model
raw_model = model.module if hasattr(model, 'module') else model
# 2. Sum the gradient norms of the MoE gates (check whether routing still receives gradients)
gate_grad_sum = 0.0
gate_count = 0
for name, param in raw_model.named_parameters():
# pick out the weights of every gating layer
if "gate.weight" in name and param.grad is not None:
gate_grad_sum += param.grad.norm().item()
gate_count += 1
avg_gate_grad = gate_grad_sum / gate_count if gate_count > 0 else 0.0
# 3. Print a detailed log line
# Main: main-task loss (lower is better)
# Aux: load-balancing loss (too low may mean the router only cares about spreading tokens evenly; too high means the load is not balanced at all)
# G_Grad: gate gradient (if this reads 0.0000, routing is completely dead and not learning anything)
debug_info = f"Main:{main_loss.item():.4f} Aux:{aux_loss.item():.4f} G_Grad:{avg_gate_grad:.6f}"
Logger(
f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}) Loss:{current_loss:.4f} {debug_info} lr:{current_lr:.8f} Time:{eta_min}m')
if wandb:
wandb.log({
"total_loss": current_loss,
"main_loss": main_loss.item(),
"aux_loss": aux_loss.item(),
"gate_grad_norm": avg_gate_grad,
"lr": current_lr,
"epoch_Time": eta_min
})
# Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
#
# if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
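The inline monitoring above can also be factored into a small helper; this is an illustrative refactor (the function name is hypothetical), and it must run after backward() but before the gradients are zeroed so that param.grad is populated. The "gate.weight" substring match mirrors the loop in the diff:

```python
import torch

def avg_gate_grad_norm(model: torch.nn.Module) -> float:
    """Average gradient norm over all MoE gate weights (0.0 if none have gradients)."""
    raw_model = model.module if hasattr(model, 'module') else model  # unwrap DDP
    total, count = 0.0, 0
    for name, param in raw_model.named_parameters():
        if "gate.weight" in name and param.grad is not None:
            total += param.grad.norm().item()
            count += 1
    return total / count if count > 0 else 0.0
```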
@@ -84,25 +129,28 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=1, help="训练轮数建议1轮zero或2-6轮充分训练")
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
parser.add_argument("--epochs", type=int, default=3, help="训练轮数建议1轮zero或2-6轮充分训练")
parser.add_argument("--batch_size", type=int, default=64, help="batch size")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=8, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--num_workers", type=int, default=64, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=16, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument("--save_interval", type=int, default=500, help="模型保存间隔")
parser.add_argument('--hidden_size', default=640, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--use_moe', default=1, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径")
parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练为none则从头开始")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain-640-moe-wramup", help="wandb项目名")
parser.add_argument("--warmup_steps", type=int, default=3000, help="Warmup步数0表示不使用warmup建议为总步数的1-5%")
parser.add_argument("--warmup_ratio", type=float, default=0.01,
help="Warmup起始学习率比例起始lr = lr * warmup_ratio")
args = parser.parse_args()
# ========== 1. Initialize the environment and random seed ==========
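A back-of-the-envelope check that the new --warmup_steps default of 3000 lands in the suggested 1-5% of total schedule steps, given the new --epochs and --batch_size defaults. The dataset size below is hypothetical; get_lr is called once per batch, so total steps are epochs * batches_per_epoch:

```python
import math

num_samples = 1_400_000      # hypothetical size of pretrain_hq.jsonl; substitute your own count
batch_size = 64
epochs = 3

iters_per_epoch = math.ceil(num_samples / batch_size)
total_steps = epochs * iters_per_epoch          # the schedule is stepped once per batch
warmup_steps = 3000
print(f"warmup fraction: {warmup_steps / total_steps:.2%}")  # ≈ 4.57% with these numbers
```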

View File

@@ -21,8 +21,32 @@ def Logger(content):
print(content)
def get_lr(current_step, total_steps, lr):
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
def get_lr(current_step, total_steps, lr, warmup_steps=0, warmup_ratio=0.01):
"""
学习率调度Warmup + Cosine Decay
Args:
current_step: 当前步数
total_steps: 总步数
lr: 目标学习率
warmup_steps: warmup步数0表示不使用warmup
warmup_ratio: warmup起始学习率比例起始lr = lr * warmup_ratio
Returns:
当前步数对应的学习率
"""
if warmup_steps > 0 and current_step < warmup_steps:
# Warmup phase: grow linearly from warmup_ratio * lr up to lr
return warmup_ratio * lr + (1 - warmup_ratio) * lr * (current_step / warmup_steps)
else:
# Cosine-decay phase: cosine decay from the peak down to lr / 10
if warmup_steps > 0:
cosine_steps = current_step - warmup_steps
cosine_total = total_steps - warmup_steps
else:
cosine_steps = current_step
cosine_total = total_steps
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * cosine_steps / cosine_total))
def init_distributed_mode():
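A quick sanity check of the new schedule, meant to run alongside the get_lr defined above (the file already imports math); the step totals are illustrative, and lr mirrors the new --learning_rate default:

```python
lr, total_steps, warmup_steps, warmup_ratio = 1e-4, 60_000, 3000, 0.01

for step in (0, 1500, 3000, 31_500, 60_000):
    print(step, f"{get_lr(step, total_steps, lr, warmup_steps, warmup_ratio):.2e}")
# 0      -> 1.00e-06   (lr * warmup_ratio)
# 1500   -> 5.05e-05   (halfway through the linear warmup)
# 3000   -> 1.10e-04   (cosine branch starts at lr/10 + lr)
# 31500  -> 6.00e-05   (midpoint of the cosine segment)
# 60000  -> 1.00e-05   (floor: lr / 10)
```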