mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00
[feat] update trainer
This commit is contained in:
parent
8f7e07b8ef
commit
acd5925193
@ -119,9 +119,9 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
parser.add_argument('--save_weight', default='dpo', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
|
||||
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
|
||||
parser.add_argument("--batch_size", type=int, default=4, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-8, help="初始学习率(建议<=5e-8避免遗忘)")
|
||||
parser.add_argument("--learning_rate", type=float, default=4e-8, help="初始学习率(建议<=5e-8避免遗忘)")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user