diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py index 9b7f011..cea2ffb 100644 --- a/trainer/train_full_sft.py +++ b/trainer/train_full_sft.py @@ -87,7 +87,7 @@ if __name__ == "__main__": parser.add_argument('--save_weight', default='full_sft', type=str, help="保存权重的前缀名") parser.add_argument("--epochs", type=int, default=2, help="训练轮数") parser.add_argument("--batch_size", type=int, default=16, help="batch size") - parser.add_argument("--learning_rate", type=float, default=5e-7, help="初始学习率") + parser.add_argument("--learning_rate", type=float, default=1e-6, help="初始学习率") parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备") parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型") parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")