[fix] dtype & lr

This commit is contained in:
jingyaogong 2025-12-09 13:01:38 +08:00
parent aa7dc0f61e
commit 5129f0e2a2
2 changed files with 2 additions and 3 deletions

View File

@@ -313,7 +313,7 @@ class MOEFeedForward(nn.Module):
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
y = torch.empty_like(x, dtype=x.dtype)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)

View File

@@ -23,8 +23,7 @@ def Logger(content):
def get_lr(current_step, total_steps, lr):
min_lr = lr / 10
return min_lr + 0.5 * (lr - min_lr) * (1 + math.cos(math.pi * current_step / total_steps))
return lr*(0.1 + 0.45*(1 + math.cos(math.pi * current_step / total_steps)))
def init_distributed_mode():