mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
[fix]:Fixed the gradient explosion problem that occurred during pre-training of the MOE module.
This commit is contained in:
parent
cc29d9a351
commit
c37d924b47
@ -34,7 +34,7 @@ class MiniMindConfig(PretrainedConfig):
|
||||
n_routed_experts: int = 4,
|
||||
n_shared_experts: int = 1,
|
||||
scoring_func: str = 'softmax',
|
||||
aux_loss_alpha: float = 0.1,
|
||||
aux_loss_alpha: float = 0.01,
|
||||
seq_aux: bool = True,
|
||||
norm_topk_prob: bool = True,
|
||||
**kwargs
|
||||
@ -249,7 +249,8 @@ class MoEGate(nn.Module):
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
    """Initialise the gate (router) weight in place.

    The weight is drawn from a normal distribution with mean 0.0 and
    standard deviation 0.01.
    """
    # Previously: init.kaiming_uniform_(self.weight, a=math.sqrt(5)).
    # NOTE(review): the switch to a small-std normal init belongs to the
    # commit described as fixing gradient explosion during MoE pre-training;
    # presumably it keeps the initial router logits small -- confirm.
    init.normal_(self.weight, mean=0.0, std=0.01)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
bsz, seq_len, h = hidden_states.shape
|
||||
@ -313,7 +314,7 @@ class MOEFeedForward(nn.Module):
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
if self.training:
|
||||
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
|
||||
y = torch.empty_like(x, dtype=torch.float16)
|
||||
y = torch.empty_like(x, dtype=x.dtype)
|
||||
for i, expert in enumerate(self.experts):
|
||||
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||
|
||||
@ -27,19 +27,32 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
X = X.to(args.device)
|
||||
Y = Y.to(args.device)
|
||||
loss_mask = loss_mask.to(args.device)
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
# lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate, args.warmup_steps, args.warmup_ratio)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
with autocast_ctx:
|
||||
res = model(X)
|
||||
loss = loss_fct(
|
||||
# loss = loss_fct(
|
||||
# res.logits.view(-1, res.logits.size(-1)),
|
||||
# Y.view(-1)
|
||||
# ).view(Y.size())
|
||||
#
|
||||
# loss = (loss * loss_mask).sum() / loss_mask.sum()
|
||||
# loss += res.aux_loss
|
||||
# loss = loss / args.accumulation_steps
|
||||
# --- 【新增监控 1】拆分 Loss 方便观察 ---
|
||||
raw_loss = loss_fct(
|
||||
res.logits.view(-1, res.logits.size(-1)),
|
||||
Y.view(-1)
|
||||
).view(Y.size())
|
||||
|
||||
loss = (loss * loss_mask).sum() / loss_mask.sum()
|
||||
loss += res.aux_loss
|
||||
main_loss = (raw_loss * loss_mask).sum() / loss_mask.sum()
|
||||
aux_loss = res.aux_loss # 获取辅助损失
|
||||
|
||||
# 这里的比例很关键,如果 aux_loss 比 main_loss 还大,模型就练歪了
|
||||
loss = main_loss + aux_loss
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
@ -59,10 +72,42 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
current_lr = optimizer.param_groups[-1]['lr']
|
||||
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
|
||||
|
||||
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
|
||||
|
||||
if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
|
||||
|
||||
# --- 【新增监控 2】打印 Gate 梯度和 Aux Loss ---
|
||||
# 1. 解包 DDP 模型,获取原始 model
|
||||
raw_model = model.module if hasattr(model, 'module') else model
|
||||
|
||||
# 2. 统计 MoE Gate 的梯度范数 (检查路由是否还有梯度)
|
||||
gate_grad_sum = 0.0
|
||||
gate_count = 0
|
||||
for name, param in raw_model.named_parameters():
|
||||
# 筛选出所有门控层的权重
|
||||
if "gate.weight" in name and param.grad is not None:
|
||||
gate_grad_sum += param.grad.norm().item()
|
||||
gate_count += 1
|
||||
avg_gate_grad = gate_grad_sum / gate_count if gate_count > 0 else 0.0
|
||||
|
||||
# 3. 打印详细日志
|
||||
# Main: 主任务Loss (越低越好)
|
||||
# Aux: 负载均衡Loss (太低说明可能只顾着平均分配了,太高说明完全没分配好)
|
||||
# G_Grad: Gate梯度 (如果是 0.0000,说明路由彻底死了,没有任何学习)
|
||||
debug_info = f"Main:{main_loss.item():.4f} Aux:{aux_loss.item():.4f} G_Grad:{avg_gate_grad:.6f}"
|
||||
|
||||
Logger(
|
||||
f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}) Loss:{current_loss:.4f} {debug_info} lr:{current_lr:.8f} Time:{eta_min}m')
|
||||
|
||||
if wandb:
|
||||
wandb.log({
|
||||
"total_loss": current_loss,
|
||||
"main_loss": main_loss.item(),
|
||||
"aux_loss": aux_loss.item(),
|
||||
"gate_grad_norm": avg_gate_grad,
|
||||
"lr": current_lr,
|
||||
"epoch_Time": eta_min
|
||||
})
|
||||
# Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
|
||||
#
|
||||
# if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
|
||||
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
model.eval()
|
||||
@ -84,25 +129,28 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=1, help="训练轮数(建议1轮zero或2-6轮充分训练)")
|
||||
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
|
||||
parser.add_argument("--epochs", type=int, default=3, help="训练轮数(建议1轮zero或2-6轮充分训练)")
|
||||
parser.add_argument("--batch_size", type=int, default=64, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4, help="初始学习率")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8, help="梯度累积步数")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||
parser.add_argument("--num_workers", type=int, default=64, help="数据加载线程数")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=16, help="梯度累积步数")
|
||||
parser.add_argument("--grad_clip", type=float, default=1, help="梯度裁剪阈值")
|
||||
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
||||
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument("--save_interval", type=int, default=500, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=640, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||
parser.add_argument('--use_moe', default=1, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径")
|
||||
parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练,为none则从头开始")
|
||||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain-640-moe-wramup", help="wandb项目名")
|
||||
# argparse %-formats help strings, so a literal percent sign must be written as "%%";
# the unescaped "%" made rendering the help text (e.g. via --help) raise ValueError.
parser.add_argument("--warmup_steps", type=int, default=3000, help="Warmup步数,0表示不使用warmup(建议为总步数的1-5%%)")
|
||||
parser.add_argument("--warmup_ratio", type=float, default=0.01,
|
||||
help="Warmup起始学习率比例,起始lr = lr * warmup_ratio")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
|
||||
@ -21,8 +21,32 @@ def Logger(content):
|
||||
print(content)
|
||||
|
||||
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
def get_lr(current_step, total_steps, lr, warmup_steps=0, warmup_ratio=0.01):
    """Return the learning rate for a step: linear warmup, then cosine decay.

    Args:
        current_step: Index of the current optimisation step.
        total_steps: Total number of steps in the schedule.
        lr: Target (base) learning rate.
        warmup_steps: Number of warmup steps; values <= 0 disable warmup.
        warmup_ratio: The warmup phase starts at ``lr * warmup_ratio``.

    Returns:
        The learning rate to use at ``current_step``.

    Notes:
        The cosine phase evaluates
        ``lr / 10 + 0.5 * lr * (1 + cos(pi * progress))``, so it runs from
        ``1.1 * lr`` at progress 0 down to ``0.1 * lr`` at progress 1. With
        warmup enabled the rate therefore steps from just under ``lr`` to
        ``1.1 * lr`` at the warmup boundary.

        If no steps remain for the cosine phase (for example
        ``warmup_steps >= total_steps`` or ``total_steps == 0``), the floor
        value ``0.1 * lr`` is returned instead of dividing by zero.
    """
    use_warmup = warmup_steps > 0

    if use_warmup and current_step < warmup_steps:
        # Warmup: ramp linearly from warmup_ratio * lr towards lr.
        return warmup_ratio * lr + (1 - warmup_ratio) * lr * (current_step / warmup_steps)

    # Cosine decay is measured from the end of warmup (or from step 0).
    offset = warmup_steps if use_warmup else 0
    cosine_steps = current_step - offset
    cosine_total = total_steps - offset

    if cosine_total > 0:
        angle = math.pi * cosine_steps / cosine_total
    else:
        # No decay phase exists; treat the schedule as fully decayed.
        angle = math.pi

    return lr / 10 + 0.5 * lr * (1 + math.cos(angle))
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user