Mirror of https://github.com/jingyaogong/minimind.git
Merge 01c04f519b into 693fb1ccf1 (commit 00ec16c61f)

trainer/muon.py (new file, 286 lines)
@@ -0,0 +1,286 @@
import torch
import torch.distributed as dist


def zeropower_via_newtonschulz5(G, steps: int):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert G.ndim >= 2  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
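# Illustrative sketch (not part of the upstream Muon file): one way to see the "approximate
# orthogonalization" described in the docstring is to inspect the singular values of the
# output, which should cluster near 1 and land roughly in the (0.5, 1.5) band rather than
# being exactly 1:
#
#     G = torch.randn(256, 512)
#     X = zeropower_via_newtonschulz5(G, steps=5)
#     print(torch.linalg.svdvals(X.float()))  # roughly within (0.5, 1.5)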
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    momentum.lerp_(grad, 1 - beta)
    update = grad.lerp_(momentum, beta) if nesterov else momentum
    if update.ndim == 4:  # for the case of conv filters
        update = update.view(len(update), -1)
    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
    update *= max(1, update.size(-2) / update.size(-1))**0.5
    return update
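# Note (not upstream code): the final max(1, rows/cols)**0.5 rescale compensates for tall
# matrices so the per-element scale of the orthogonalized update stays roughly comparable
# across layer shapes. For example, a 1024x256 weight is scaled by
# max(1, 1024/256)**0.5 = 2.0, while a 256x1024 weight keeps a factor of 1.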
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
    advantage that it can be stably run in bfloat16 on the GPU.

    Muon should only be used for hidden weight layers. The input embedding, final output layer,
    and any internal gains or biases should be optimized using a standard method such as AdamW.
    Hidden convolutional weights can be trained using Muon by viewing them as 2D and then
    collapsing their last 3 dimensions.

    Arguments:
        lr: The learning rate, in units of spectral norm per update.
        weight_decay: The AdamW-style weight decay.
        momentum: The momentum. A value of 0.95 here is usually fine.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
        params = sorted(params, key=lambda x: x.size(), reverse=True)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params = group["params"]
            params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
            for base_i in range(len(params))[::dist.get_world_size()]:
                if base_i + dist.get_rank() < len(params):
                    p = params[base_i + dist.get_rank()]
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
                dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])

        return loss


class SingleDeviceMuon(torch.optim.Optimizer):
    """
    Muon variant for usage in non-distributed settings.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    # continue
                    p.grad = torch.zeros_like(p)  # Force synchronization
                state = self.state[p]
                if len(state) == 0:
                    state["momentum_buffer"] = torch.zeros_like(p)
                update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                p.mul_(1 - group["lr"] * group["weight_decay"])
                p.add_(update.reshape(p.shape), alpha=-group["lr"])

        return loss
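# Illustrative usage sketch (not part of the upstream file; the model and parameter names
# below are hypothetical): on a single device, Muon is typically paired with AdamW, with the
# >=2D hidden weights going to SingleDeviceMuon and embeddings/scalars going to AdamW:
#
#     hidden_weights = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n]
#     other_params = [p for n, p in model.named_parameters() if p.ndim < 2 or "embed" in n]
#     opt_muon = SingleDeviceMuon(hidden_weights, lr=0.02, momentum=0.95)
#     opt_adam = torch.optim.AdamW(other_params, lr=3e-4)
#     # ...call opt_muon.step() and opt_adam.step() each iteration, or use the
#     # SingleDeviceMuonWithAuxAdam class below to keep everything in a single optimizer.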
def adam_update(grad, buf1, buf2, step, betas, eps):
    buf1.lerp_(grad, 1 - betas[0])
    buf2.lerp_(grad.square(), 1 - betas[1])
    buf1c = buf1 / (1 - betas[0]**step)
    buf2c = buf2 / (1 - betas[1]**step)
    return buf1c / (buf2c.sqrt() + eps)
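# Note: adam_update is the standard bias-corrected Adam step (first/second moment EMAs each
# divided by 1 - beta**t). Weight decay is not applied here; it is applied decoupled,
# AdamW-style, in the step() methods below via p.mul_(1 - lr * weight_decay).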
class MuonWithAuxAdam(torch.optim.Optimizer):
    """
    Distributed Muon variant that can be used for all parameters in the network, since it runs an
    internal AdamW for the parameters that are not compatible with Muon. The user must manually
    specify which parameters shall be optimized with Muon and which with Adam by passing in a
    list of param_groups with the `use_muon` flag set.

    The point of this class is to allow the user to have a single optimizer in their code, rather
    than having both a Muon and an Adam which each need to be stepped.

    You can see an example usage below:

    https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470
    ```
    hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
    scalar_params = [p for p in model.parameters() if p.ndim < 2]
    head_params = [model.lm_head.weight]

    from muon import MuonWithAuxAdam
    adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
    adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups]
    muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True)
    param_groups = [*adam_groups, muon_group]
    optimizer = MuonWithAuxAdam(param_groups)
    ```
    """
    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group["use_muon"]:
                params = group["params"]
                params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
                for base_i in range(len(params))[::dist.get_world_size()]:
                    if base_i + dist.get_rank() < len(params):
                        p = params[base_i + dist.get_rank()]
                        if p.grad is None:
                            # continue
                            p.grad = torch.zeros_like(p)  # Force synchronization
                        state = self.state[p]
                        if len(state) == 0:
                            state["momentum_buffer"] = torch.zeros_like(p)
                        update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                        p.mul_(1 - group["lr"] * group["weight_decay"])
                        p.add_(update.reshape(p.shape), alpha=-group["lr"])
                    dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                         state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss
class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
    """
    Non-distributed variant of MuonWithAuxAdam.
    """
    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group["use_muon"]:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                         state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss
trainer/train_full_sft.py
@@ -16,6 +16,7 @@ from torch.utils.data import DataLoader, DistributedSampler
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
from trainer.muon import MuonWithAuxAdam, SingleDeviceMuonWithAuxAdam

warnings.filterwarnings('ignore')

@@ -104,6 +105,7 @@ if __name__ == "__main__":
    parser.add_argument("--use_wandb", action="store_true", help="whether to use wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="wandb project name")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="whether to use torch.compile acceleration (0=no, 1=yes)")
    parser.add_argument("--optimizer", type=str, default="adamw", choices=["adamw", "muon"], help="optimizer type (adamw or muon)")
    args = parser.parse_args()

    # ========== 1. Initialize environment and random seed ==========
@@ -135,7 +137,24 @@
    train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    # Configure the optimizer
    if args.optimizer == "muon":
        # Muon: hidden weights (>=2D) use Muon, everything else uses AdamW
        hidden_weights = [p for p in model.parameters() if p.ndim >= 2]
        nonhidden_params = [p for p in model.parameters() if p.ndim < 2]
        param_groups = [
            dict(params=hidden_weights, use_muon=True, lr=args.learning_rate, weight_decay=0.01),
            dict(params=nonhidden_params, use_muon=False, lr=5e-4, betas=(0.9, 0.95)),
        ]
        if dist.is_initialized():
            optimizer = MuonWithAuxAdam(param_groups)
        else:
            optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
        Logger(f'Using Muon optimizer with {len(hidden_weights)} hidden weights and {len(nonhidden_params)} non-hidden params')
    else:
        optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
        Logger('Using AdamW optimizer')

    # ========== 6. Resume state from checkpoint ==========
    start_epoch, start_step = 0, 0
@@ -167,4 +186,4 @@
        train_epoch(epoch, loader, len(loader), 0, wandb)

    # ========== 9. Clean up distributed processes ==========
    if dist.is_initialized(): dist.destroy_process_group()
    if dist.is_initialized(): dist.destroy_process_group()
trainer/train_pretrain.py
@@ -16,6 +16,7 @@ from torch.utils.data import DataLoader, DistributedSampler
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import PretrainDataset
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
from trainer.muon import MuonWithAuxAdam, SingleDeviceMuonWithAuxAdam

warnings.filterwarnings('ignore')

@@ -103,6 +104,7 @@ if __name__ == "__main__":
    parser.add_argument("--use_wandb", action="store_true", help="whether to use wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb project name")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="whether to use torch.compile acceleration (0=no, 1=yes)")
    parser.add_argument("--optimizer", type=str, default="adamw", choices=["adamw", "muon"], help="optimizer type (adamw or muon)")
    args = parser.parse_args()

    # ========== 1. Initialize environment and random seed ==========
@@ -134,7 +136,24 @@
    train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    # Configure the optimizer
    if args.optimizer == "muon":
        # Muon: hidden weights (>=2D) use Muon, everything else uses AdamW
        hidden_weights = [p for p in model.parameters() if p.ndim >= 2]
        nonhidden_params = [p for p in model.parameters() if p.ndim < 2]
        param_groups = [
            dict(params=hidden_weights, use_muon=True, lr=args.learning_rate, weight_decay=0.01),
            dict(params=nonhidden_params, use_muon=False, lr=5e-4, betas=(0.9, 0.95)),
        ]
        if dist.is_initialized():
            optimizer = MuonWithAuxAdam(param_groups)
        else:
            optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
        Logger(f'Using Muon optimizer with {len(hidden_weights)} hidden weights and {len(nonhidden_params)} non-hidden params')
    else:
        optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
        Logger('Using AdamW optimizer')

    # ========== 6. Resume state from checkpoint ==========
    start_epoch, start_step = 0, 0
@@ -166,4 +185,4 @@
        train_epoch(epoch, loader, len(loader), 0, wandb)

    # ========== 9. Clean up distributed processes ==========
    if dist.is_initialized(): dist.destroy_process_group()
    if dist.is_initialized(): dist.destroy_process_group()