diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py index 97716a5..fc2eb82 100644 --- a/trainer/train_dpo.py +++ b/trainer/train_dpo.py @@ -119,9 +119,9 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)") parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录") parser.add_argument('--save_weight', default='dpo', type=str, help="保存权重的前缀名") - parser.add_argument("--epochs", type=int, default=2, help="训练轮数") + parser.add_argument("--epochs", type=int, default=1, help="训练轮数") parser.add_argument("--batch_size", type=int, default=4, help="batch size") - parser.add_argument("--learning_rate", type=float, default=5e-8, help="初始学习率(建议<=5e-8避免遗忘)") + parser.add_argument("--learning_rate", type=float, default=4e-8, help="初始学习率(建议<=5e-8避免遗忘)") parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备") parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型") parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")