[feat] add compile

This commit is contained in:
jingyaogong
2026-01-14 14:42:30 +08:00
parent 1279a61681
commit 81d24a4f16
9 changed files with 36 additions and 0 deletions
+4
View File
@@ -162,6 +162,7 @@ if __name__ == "__main__":
parser.add_argument('--temperature', default=1.5, type=float, help="蒸馏温度(推荐范围1.0-2.0")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Distillation", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
@@ -191,6 +192,9 @@ if __name__ == "__main__":
# ========== 5. 定义学生和教师模型 ==========
model, tokenizer = init_model(lm_config_student, args.from_student_weight, device=args.device)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
Logger(f'学生模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
teacher_model, _ = init_model(lm_config_teacher, args.from_teacher_weight, device=args.device)
teacher_model.eval()
+4
View File
@@ -146,6 +146,7 @@ if __name__ == "__main__":
parser.add_argument('--beta', default=0.1, type=float, help="DPO中的beta参数")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
@@ -174,6 +175,9 @@ if __name__ == "__main__":
# ========== 5. 定义模型和参考模型 ==========
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
Logger(f'策略模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
# 初始化参考模型(ref_model冻结)
ref_model, _ = init_model(lm_config, args.from_weight, device=args.device)
+4
View File
@@ -106,6 +106,7 @@ if __name__ == "__main__":
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
@@ -134,6 +135,9 @@ if __name__ == "__main__":
# ========== 5. 定义模型、数据、优化器 ==========
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
+4
View File
@@ -218,6 +218,7 @@ if __name__ == "__main__":
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
@@ -249,6 +250,9 @@ if __name__ == "__main__":
base_weight = "reason" if args.reasoning == 1 else "full_sft"
# Policy模型
model, tokenizer = init_model(lm_config, base_weight, device=args.device)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
# Reference模型
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
ref_model = ref_model.eval().requires_grad_(False)
+4
View File
@@ -100,6 +100,7 @@ if __name__ == "__main__":
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
@@ -128,6 +129,9 @@ if __name__ == "__main__":
# ========== 5. 定义模型、应用LoRA、冻结非LoRA参数 ==========
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
apply_lora(model)
# 统计参数
+4
View File
@@ -274,6 +274,7 @@ if __name__ == "__main__":
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
@@ -304,6 +305,9 @@ if __name__ == "__main__":
base_weight = "reason" if args.reasoning == 1 else "full_sft"
# Actor模型
actor_model, tokenizer = init_model(lm_config, base_weight, device=args.device)
if args.use_compile == 1:
actor_model = torch.compile(actor_model)
Logger('torch.compile enabled')
# Old Actor模型
old_actor_model, _ = init_model(lm_config, base_weight, device=args.device)
old_actor_model = old_actor_model.eval().requires_grad_(False)
+4
View File
@@ -105,6 +105,7 @@ if __name__ == "__main__":
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
@@ -133,6 +134,9 @@ if __name__ == "__main__":
# ========== 5. 定义模型、数据、优化器 ==========
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
+4
View File
@@ -118,6 +118,7 @@ if __name__ == "__main__":
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Reasoning", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
@@ -146,6 +147,9 @@ if __name__ == "__main__":
# ========== 5. 定义模型、数据、优化器 ==========
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
+4
View File
@@ -265,6 +265,7 @@ if __name__ == "__main__":
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-SPO", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
@@ -296,6 +297,9 @@ if __name__ == "__main__":
base_weight = "reason" if args.reasoning == 1 else "full_sft"
# Policy模型
model, tokenizer = init_model(lm_config, base_weight, device=args.device)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
# Reference模型
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
ref_model = ref_model.eval().requires_grad_(False)