diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py index adbfdce..133385f 100644 --- a/trainer/train_full_sft.py +++ b/trainer/train_full_sft.py @@ -139,7 +139,7 @@ if __name__ == "__main__": nonhidden_params = [p for p in model.parameters() if p.ndim < 2] param_groups = [ dict(params=hidden_weights, use_muon=True, lr=args.learning_rate, weight_decay=0.01), - dict(params=nonhidden_params, use_muon=False, lr=args.learning_rate, betas=(0.9, 0.95)), + dict(params=nonhidden_params, use_muon=False, lr=5e-4, betas=(0.9, 0.95)), ] if dist.is_initialized(): optimizer = MuonWithAuxAdam(param_groups) @@ -178,4 +178,4 @@ if __name__ == "__main__": train_epoch(epoch, loader, len(loader), 0, wandb) # ========== 9. 清理分布进程 ========== - if dist.is_initialized(): dist.destroy_process_group() \ No newline at end of file + if dist.is_initialized(): dist.destroy_process_group() diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py index 7b3cb1d..a4330cc 100644 --- a/trainer/train_pretrain.py +++ b/trainer/train_pretrain.py @@ -138,7 +138,7 @@ if __name__ == "__main__": nonhidden_params = [p for p in model.parameters() if p.ndim < 2] param_groups = [ dict(params=hidden_weights, use_muon=True, lr=args.learning_rate, weight_decay=0.01), - dict(params=nonhidden_params, use_muon=False, lr=args.learning_rate, betas=(0.9, 0.95)), + dict(params=nonhidden_params, use_muon=False, lr=5e-4, betas=(0.9, 0.95)), ] if dist.is_initialized(): optimizer = MuonWithAuxAdam(param_groups) @@ -177,4 +177,4 @@ if __name__ == "__main__": train_epoch(epoch, loader, len(loader), 0, wandb) # ========== 9. 清理分布进程 ========== - if dist.is_initialized(): dist.destroy_process_group() \ No newline at end of file + if dist.is_initialized(): dist.destroy_process_group()