mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-23 15:58:15 +08:00
[fix] dtype & lr
This commit is contained in:
parent
aa7dc0f61e
commit
5129f0e2a2
@@ -313,7 +313,7 @@ class MOEFeedForward(nn.Module):
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
if self.training:
|
||||
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
|
||||
- y = torch.empty_like(x, dtype=torch.float16)
|
||||
+ y = torch.empty_like(x, dtype=x.dtype)
|
||||
for i, expert in enumerate(self.experts):
|
||||
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||
|
||||
@@ -23,8 +23,7 @@ def Logger(content):
|
||||
|
||||
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
- min_lr = lr / 10
|
||||
- return min_lr + 0.5 * (lr - min_lr) * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
+ return lr*(0.1 + 0.45*(1 + math.cos(math.pi * current_step / total_steps)))
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user