From 01c04f519b39f4a3cab3f9f794f6f5a12499724c Mon Sep 17 00:00:00 2001 From: guo-sj Date: Mon, 2 Feb 2026 14:48:22 +0800 Subject: [PATCH] adjust nonhidden_params learning_rate to 5e-4 --- trainer/train_full_sft.py | 4 ++-- trainer/train_pretrain.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py index adbfdce..133385f 100644 --- a/trainer/train_full_sft.py +++ b/trainer/train_full_sft.py @@ -139,7 +139,7 @@ if __name__ == "__main__": nonhidden_params = [p for p in model.parameters() if p.ndim < 2] param_groups = [ dict(params=hidden_weights, use_muon=True, lr=args.learning_rate, weight_decay=0.01), - dict(params=nonhidden_params, use_muon=False, lr=args.learning_rate, betas=(0.9, 0.95)), + dict(params=nonhidden_params, use_muon=False, lr=5e-4, betas=(0.9, 0.95)), ] if dist.is_initialized(): optimizer = MuonWithAuxAdam(param_groups) @@ -178,4 +178,4 @@ if __name__ == "__main__": train_epoch(epoch, loader, len(loader), 0, wandb) # ========== 9. 清理分布进程 ========== - if dist.is_initialized(): dist.destroy_process_group() \ No newline at end of file + if dist.is_initialized(): dist.destroy_process_group() diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py index 7b3cb1d..a4330cc 100644 --- a/trainer/train_pretrain.py +++ b/trainer/train_pretrain.py @@ -138,7 +138,7 @@ if __name__ == "__main__": nonhidden_params = [p for p in model.parameters() if p.ndim < 2] param_groups = [ dict(params=hidden_weights, use_muon=True, lr=args.learning_rate, weight_decay=0.01), - dict(params=nonhidden_params, use_muon=False, lr=args.learning_rate, betas=(0.9, 0.95)), + dict(params=nonhidden_params, use_muon=False, lr=5e-4, betas=(0.9, 0.95)), ] if dist.is_initialized(): optimizer = MuonWithAuxAdam(param_groups) @@ -177,4 +177,4 @@ if __name__ == "__main__": train_epoch(epoch, loader, len(loader), 0, wandb) # ========== 9. 清理分布进程 ========== - if dist.is_initialized(): dist.destroy_process_group() \ No newline at end of file + if dist.is_initialized(): dist.destroy_process_group()