From 5129f0e2a28faecef58859bb9c0837457ad6afa9 Mon Sep 17 00:00:00 2001 From: jingyaogong Date: Tue, 9 Dec 2025 13:01:38 +0800 Subject: [PATCH] [fix] dtype & lr --- model/model_minimind.py | 2 +- trainer/trainer_utils.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/model/model_minimind.py b/model/model_minimind.py index ad62a68..f7e49d1 100755 --- a/model/model_minimind.py +++ b/model/model_minimind.py @@ -313,7 +313,7 @@ class MOEFeedForward(nn.Module): flat_topk_idx = topk_idx.view(-1) if self.training: x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0) - y = torch.empty_like(x, dtype=torch.float16) + y = torch.empty_like(x, dtype=x.dtype) for i, expert in enumerate(self.experts): y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致 y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) diff --git a/trainer/trainer_utils.py b/trainer/trainer_utils.py index 0ad1b00..c1a4ca2 100644 --- a/trainer/trainer_utils.py +++ b/trainer/trainer_utils.py @@ -23,8 +23,7 @@ def Logger(content): def get_lr(current_step, total_steps, lr): - min_lr = lr / 10 - return min_lr + 0.5 * (lr - min_lr) * (1 + math.cos(math.pi * current_step / total_steps)) + return lr*(0.1 + 0.45*(1 + math.cos(math.pi * current_step / total_steps))) def init_distributed_mode():