From 8972dab6f531ed2b14ff98e7a4f0d7ca65bc03de Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E5=BE=97=E5=88=A9?=
Date: Thu, 19 Mar 2026 12:44:53 +0000
Subject: [PATCH] fix SFT resume with compile mode

Load checkpoint state before torch.compile so state_dict keys match on
resume: torch.compile wraps the model in an OptimizedModule whose
state_dict keys carry an '_orig_mod.' prefix, so a checkpoint saved
from an uncompiled model no longer matches once the model is compiled.
Also make optimizer/scaler restore tolerant of incompatible or missing
fields in older checkpoints, logging a warning instead of failing.

Made-with: Cursor
---
 trainer/train_full_sft.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py
index cc59cc7..5b4cc86 100644
--- a/trainer/train_full_sft.py
+++ b/trainer/train_full_sft.py
@@ -123,22 +123,31 @@ if __name__ == "__main__":
 
     # ========== 5. Define model, data, optimizer ==========
     model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
-    if args.use_compile == 1:
-        model = torch.compile(model)
-        Logger('torch.compile enabled')
     train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
     train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
     scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
 
-    # ========== 6. Restore state from ckp ==========
+    # ========== 6. Restore state from ckp (must run before torch.compile) ==========
     start_epoch, start_step = 0, 0
     if ckp_data:
-        model.load_state_dict(ckp_data['model'])
-        optimizer.load_state_dict(ckp_data['optimizer'])
-        scaler.load_state_dict(ckp_data['scaler'])
+        model.load_state_dict(ckp_data['model'], strict=False)
+        if ckp_data.get('optimizer') is not None:
+            try:
+                optimizer.load_state_dict(ckp_data['optimizer'])
+            except Exception:
+                Logger('optimizer state in checkpoint is incompatible, skipping restore')
+        if ckp_data.get('scaler') is not None:
+            try:
+                scaler.load_state_dict(ckp_data['scaler'])
+            except Exception:
+                Logger('scaler state in checkpoint is incompatible, skipping restore')
         start_epoch = ckp_data['epoch']
         start_step = ckp_data.get('step', 0)
+
+    if args.use_compile == 1:
+        model = torch.compile(model)
+        Logger('torch.compile enabled')
 
     # ========== 7. Wrap model with DDP ==========
     if dist.is_initialized():
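
For reviewers: a minimal sketch of the key mismatch this patch avoids,
assuming PyTorch 2.x. The nn.Linear is a hypothetical stand-in for the
project's real model; the names below are illustrative and not taken
from train_full_sft.py.

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)                    # stand-in for the real model
    plain_ckpt = model.state_dict()            # keys: 'weight', 'bias'

    compiled = torch.compile(model)            # wraps model in an OptimizedModule
    print(list(compiled.state_dict().keys()))  # ['_orig_mod.weight', '_orig_mod.bias']

    # Loading a pre-compile checkpoint into the compiled wrapper fails strict
    # key matching ('weight' vs. '_orig_mod.weight'):
    #   compiled.load_state_dict(plain_ckpt)   # RuntimeError: missing/unexpected keys

    # Loading before compiling, as this patch does, sidesteps the prefix:
    model.load_state_dict(plain_ckpt)
    compiled = torch.compile(model)

Some projects instead strip the '_orig_mod.' prefix from state_dict
keys when saving or loading; reordering the load as done here avoids
any key rewriting.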