diff --git a/trainer/muon.py b/trainer/muon.py
new file mode 100644
index 0000000..aafa9d7
--- /dev/null
+++ b/trainer/muon.py
@@ -0,0 +1,286 @@
+import torch
+import torch.distributed as dist
+
+
+def zeropower_via_newtonschulz5(G, steps: int):
+    """
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert G.ndim >= 2  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+    # Perform the NS iterations
+    for _ in range(steps):
+        A = X @ X.mT
+        B = b * A + c * A @ A  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
+
+
+def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
+    momentum.lerp_(grad, 1 - beta)
+    update = grad.lerp_(momentum, beta) if nesterov else momentum
+    if update.ndim == 4:  # for the case of conv filters
+        update = update.view(len(update), -1)
+    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
+    update *= max(1, update.size(-2) / update.size(-1))**0.5
+    return update
+
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    https://kellerjordan.github.io/posts/muon/
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
+    advantage that it can be stably run in bfloat16 on the GPU.
+
+    Muon should only be used for hidden weight layers. The input embedding, final output layer,
+    and any internal gains or biases should be optimized using a standard method such as AdamW.
+    Hidden convolutional weights can be trained using Muon by viewing them as 2D and then
+    collapsing their last 3 dimensions.
+
+    Arguments:
+        lr: The learning rate, in units of spectral norm per update.
+        weight_decay: The AdamW-style weight decay.
+        momentum: The momentum. A value of 0.95 here is usually fine.
+ """ + def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter) + params = sorted(params, key=lambda x: x.size(), reverse=True) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size()) + for base_i in range(len(params))[::dist.get_world_size()]: + if base_i + dist.get_rank() < len(params): + p = params[base_i + dist.get_rank()] + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update.reshape(p.shape), alpha=-group["lr"]) + dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) + + return loss + + +class SingleDeviceMuon(torch.optim.Optimizer): + """ + Muon variant for usage in non-distributed settings. + """ + def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update.reshape(p.shape), alpha=-group["lr"]) + + return loss + + +def adam_update(grad, buf1, buf2, step, betas, eps): + buf1.lerp_(grad, 1 - betas[0]) + buf2.lerp_(grad.square(), 1 - betas[1]) + buf1c = buf1 / (1 - betas[0]**step) + buf2c = buf2 / (1 - betas[1]**step) + return buf1c / (buf2c.sqrt() + eps) + + +class MuonWithAuxAdam(torch.optim.Optimizer): + """ + Distributed Muon variant that can be used for all parameters in the network, since it runs an + internal AdamW for the parameters that are not compatible with Muon. The user must manually + specify which parameters shall be optimized with Muon and which with Adam by passing in a + list of param_groups with the `use_muon` flag set. + + The point of this class is to allow the user to have a single optimizer in their code, rather + than having both a Muon and an Adam which each need to be stepped. 
+
+    You can see an example usage below:
+
+    https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470
+    ```
+    hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+    scalar_params = [p for p in model.parameters() if p.ndim < 2]
+    head_params = [model.lm_head.weight]
+
+    from muon import MuonWithAuxAdam
+    adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
+    adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups]
+    muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True)
+    param_groups = [*adam_groups, muon_group]
+    optimizer = MuonWithAuxAdam(param_groups)
+    ```
+    """
+    def __init__(self, param_groups):
+        for group in param_groups:
+            assert "use_muon" in group
+            if group["use_muon"]:
+                group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
+                # defaults
+                group["lr"] = group.get("lr", 0.02)
+                group["momentum"] = group.get("momentum", 0.95)
+                group["weight_decay"] = group.get("weight_decay", 0)
+                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
+            else:
+                # defaults
+                group["lr"] = group.get("lr", 3e-4)
+                group["betas"] = group.get("betas", (0.9, 0.95))
+                group["eps"] = group.get("eps", 1e-10)
+                group["weight_decay"] = group.get("weight_decay", 0)
+                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
+        super().__init__(param_groups, dict())
+
+    @torch.no_grad()
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if group["use_muon"]:
+                params = group["params"]
+                params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
+                for base_i in range(len(params))[::dist.get_world_size()]:
+                    if base_i + dist.get_rank() < len(params):
+                        p = params[base_i + dist.get_rank()]
+                        if p.grad is None:
+                            # continue
+                            p.grad = torch.zeros_like(p)  # Force synchronization
+                        state = self.state[p]
+                        if len(state) == 0:
+                            state["momentum_buffer"] = torch.zeros_like(p)
+                        update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+                        p.mul_(1 - group["lr"] * group["weight_decay"])
+                        p.add_(update.reshape(p.shape), alpha=-group["lr"])
+                    dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
+            else:
+                for p in group["params"]:
+                    if p.grad is None:
+                        # continue
+                        p.grad = torch.zeros_like(p)  # Force synchronization
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["exp_avg"] = torch.zeros_like(p)
+                        state["exp_avg_sq"] = torch.zeros_like(p)
+                        state["step"] = 0
+                    state["step"] += 1
+                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
+                                         state["step"], group["betas"], group["eps"])
+                    p.mul_(1 - group["lr"] * group["weight_decay"])
+                    p.add_(update, alpha=-group["lr"])
+
+        return loss
+
+
+class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
+    """
+    Non-distributed variant of MuonWithAuxAdam.
+    """
+    def __init__(self, param_groups):
+        for group in param_groups:
+            assert "use_muon" in group
+            if group["use_muon"]:
+                # defaults
+                group["lr"] = group.get("lr", 0.02)
+                group["momentum"] = group.get("momentum", 0.95)
+                group["weight_decay"] = group.get("weight_decay", 0)
+                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
+            else:
+                # defaults
+                group["lr"] = group.get("lr", 3e-4)
+                group["betas"] = group.get("betas", (0.9, 0.95))
+                group["eps"] = group.get("eps", 1e-10)
+                group["weight_decay"] = group.get("weight_decay", 0)
+                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
+        super().__init__(param_groups, dict())
+
+    @torch.no_grad()
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if group["use_muon"]:
+                for p in group["params"]:
+                    if p.grad is None:
+                        # continue
+                        p.grad = torch.zeros_like(p)  # Force synchronization
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["momentum_buffer"] = torch.zeros_like(p)
+                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+                    p.mul_(1 - group["lr"] * group["weight_decay"])
+                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
+            else:
+                for p in group["params"]:
+                    if p.grad is None:
+                        # continue
+                        p.grad = torch.zeros_like(p)  # Force synchronization
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["exp_avg"] = torch.zeros_like(p)
+                        state["exp_avg_sq"] = torch.zeros_like(p)
+                        state["step"] = 0
+                    state["step"] += 1
+                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
+                                         state["step"], group["betas"], group["eps"])
+                    p.mul_(1 - group["lr"] * group["weight_decay"])
+                    p.add_(update, alpha=-group["lr"])
+
+        return loss
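A quick way to see what `zeropower_via_newtonschulz5` produces is to inspect the singular values of its output: after a few iterations they should sit near 1, though only approximately, per the `S' ~ Uniform(0.5, 1.5)` remark in its docstring. A minimal sanity-check sketch, not part of the patch:

```python
# Illustrative sanity check for zeropower_via_newtonschulz5 (not part of this patch).
import torch
from trainer.muon import zeropower_via_newtonschulz5

G = torch.randn(256, 512)                    # stand-in for a gradient matrix
X = zeropower_via_newtonschulz5(G, steps=5)  # same shape as G, returned in bfloat16

# Singular values of the orthogonalized output cluster around 1,
# while those of the raw input are spread out.
print(torch.linalg.svdvals(X.float())[:5])
print(torch.linalg.svdvals(G)[:5])
```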
+ """ + def __init__(self, param_groups): + for group in param_groups: + assert "use_muon" in group + if group["use_muon"]: + # defaults + group["lr"] = group.get("lr", 0.02) + group["momentum"] = group.get("momentum", 0.95) + group["weight_decay"] = group.get("weight_decay", 0) + assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"]) + else: + # defaults + group["lr"] = group.get("lr", 3e-4) + group["betas"] = group.get("betas", (0.9, 0.95)) + group["eps"] = group.get("eps", 1e-10) + group["weight_decay"] = group.get("weight_decay", 0) + assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"]) + super().__init__(param_groups, dict()) + + @torch.no_grad() + def step(self, closure=None): + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + for p in group["params"]: + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update.reshape(p.shape), alpha=-group["lr"]) + else: + for p in group["params"]: + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + state["step"] = 0 + state["step"] += 1 + update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], + state["step"], group["betas"], group["eps"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) + + return loss diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py index 51c755b..715f9bf 100644 --- a/trainer/train_full_sft.py +++ b/trainer/train_full_sft.py @@ -16,6 +16,7 @@ from torch.utils.data import DataLoader, DistributedSampler from model.model_minimind import MiniMindConfig from dataset.lm_dataset import SFTDataset from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler +from trainer.muon import MuonWithAuxAdam, SingleDeviceMuonWithAuxAdam warnings.filterwarnings('ignore') @@ -104,6 +105,7 @@ if __name__ == "__main__": parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb") parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="wandb项目名") parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)") + parser.add_argument("--optimizer", type=str, default="adamw", choices=["adamw", "muon"], help="优化器类型(adamw 或 muon)") args = parser.parse_args() # ========== 1. 
@@ -135,7 +137,24 @@ if __name__ == "__main__":
     train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
     train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
     scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
-    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
+
+    # Configure the optimizer
+    if args.optimizer == "muon":
+        # Muon: hidden weights (>=2D) use Muon, everything else uses AdamW
+        hidden_weights = [p for p in model.parameters() if p.ndim >= 2]
+        nonhidden_params = [p for p in model.parameters() if p.ndim < 2]
+        param_groups = [
+            dict(params=hidden_weights, use_muon=True, lr=args.learning_rate, weight_decay=0.01),
+            dict(params=nonhidden_params, use_muon=False, lr=5e-4, betas=(0.9, 0.95)),
+        ]
+        if dist.is_initialized():
+            optimizer = MuonWithAuxAdam(param_groups)
+        else:
+            optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
+        Logger(f'Using Muon optimizer with {len(hidden_weights)} hidden weights and {len(nonhidden_params)} non-hidden params')
+    else:
+        optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
+        Logger('Using AdamW optimizer')
 
     # ========== 6. 从ckp恢复状态 ==========
     start_epoch, start_step = 0, 0
@@ -167,4 +186,4 @@
         train_epoch(epoch, loader, len(loader), 0, wandb)
 
     # ========== 9. 清理分布进程 ==========
-    if dist.is_initialized(): dist.destroy_process_group()
\ No newline at end of file
+    if dist.is_initialized(): dist.destroy_process_group()
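One caveat about the grouping in the hunk above (the same grouping is repeated for train_pretrain.py below): `p.ndim >= 2` routes every matrix-shaped parameter to Muon, including the 2D token embedding (and the output head, if it is a separate weight), whereas the Muon docstring recommends AdamW for embeddings and the final layer. If that guidance is wanted, a name-based split in the spirit of the MuonWithAuxAdam docstring example could be dropped in instead. A hedged sketch, where the "embed"/"lm_head" filters are assumptions about MiniMind's parameter names and should be checked against `model.named_parameters()`:

```python
# Illustrative alternative grouping (not part of this patch): keep embedding and
# head weights in the Adam group, per the Muon docstring. The name filters below
# are assumptions and should be verified against model.named_parameters().
muon_params = [p for n, p in model.named_parameters()
               if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
adam_params = [p for n, p in model.named_parameters()
               if p.ndim < 2 or "embed" in n or "lm_head" in n]
param_groups = [
    dict(params=muon_params, use_muon=True, lr=args.learning_rate, weight_decay=0.01),
    dict(params=adam_params, use_muon=False, lr=5e-4, betas=(0.9, 0.95)),
]
```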
diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py
index 5d33de4..8893c69 100644
--- a/trainer/train_pretrain.py
+++ b/trainer/train_pretrain.py
@@ -16,6 +16,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from model.model_minimind import MiniMindConfig
 from dataset.lm_dataset import PretrainDataset
 from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
+from trainer.muon import MuonWithAuxAdam, SingleDeviceMuonWithAuxAdam
 
 warnings.filterwarnings('ignore')
 
@@ -103,6 +104,7 @@ if __name__ == "__main__":
     parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
     parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
     parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
+    parser.add_argument("--optimizer", type=str, default="adamw", choices=["adamw", "muon"], help="优化器类型(adamw 或 muon)")
     args = parser.parse_args()
 
     # ========== 1. 初始化环境和随机种子 ==========
@@ -134,7 +136,24 @@ if __name__ == "__main__":
     train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
     train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
     scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
-    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
+
+    # Configure the optimizer
+    if args.optimizer == "muon":
+        # Muon: hidden weights (>=2D) use Muon, everything else uses AdamW
+        hidden_weights = [p for p in model.parameters() if p.ndim >= 2]
+        nonhidden_params = [p for p in model.parameters() if p.ndim < 2]
+        param_groups = [
+            dict(params=hidden_weights, use_muon=True, lr=args.learning_rate, weight_decay=0.01),
+            dict(params=nonhidden_params, use_muon=False, lr=5e-4, betas=(0.9, 0.95)),
+        ]
+        if dist.is_initialized():
+            optimizer = MuonWithAuxAdam(param_groups)
+        else:
+            optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
+        Logger(f'Using Muon optimizer with {len(hidden_weights)} hidden weights and {len(nonhidden_params)} non-hidden params')
+    else:
+        optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
+        Logger('Using AdamW optimizer')
 
     # ========== 6. 从ckp恢复状态 ==========
     start_epoch, start_step = 0, 0
@@ -166,4 +185,4 @@
         train_epoch(epoch, loader, len(loader), 0, wandb)
 
     # ========== 9. 清理分布进程 ==========
-    if dist.is_initialized(): dist.destroy_process_group()
\ No newline at end of file
+    if dist.is_initialized(): dist.destroy_process_group()
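Finally, a minimal single-process smoke test of the code path the scripts fall back to when torch.distributed is not initialized. It is a sketch on a toy model, not part of the patch; the group keys mirror the ones the training scripts build above.

```python
# Smoke-test sketch for SingleDeviceMuonWithAuxAdam (illustrative only, not part of this patch).
import torch
from trainer.muon import SingleDeviceMuonWithAuxAdam

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))
hidden_weights = [p for p in model.parameters() if p.ndim >= 2]   # Linear weights -> Muon
nonhidden_params = [p for p in model.parameters() if p.ndim < 2]  # biases -> aux AdamW
param_groups = [
    dict(params=hidden_weights, use_muon=True, lr=0.02, weight_decay=0.01),
    dict(params=nonhidden_params, use_muon=False, lr=5e-4, betas=(0.9, 0.95)),
]
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

x, y = torch.randn(8, 16), torch.randn(8, 4)
for step in range(3):
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(step, loss.item())  # loss typically decreases over these few steps
```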