fix SFT resume when torch.compile is enabled

Load checkpoint state before torch.compile to avoid key mismatch on resume, and make optimizer/scaler restore tolerant to missing fields in older checkpoints.
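
Background: torch.compile wraps the module in an OptimizedModule whose state_dict keys, on typical PyTorch 2.x builds, carry an `_orig_mod.` prefix, so a checkpoint saved from the uncompiled model no longer matches when it is loaded strictly after compiling. A minimal standalone sketch of the mismatch and the fix (toy nn.Linear model, not code from this repo):

import torch
import torch.nn as nn

plain = nn.Linear(4, 4)
ckpt = plain.state_dict()                    # plain keys: 'weight', 'bias'

compiled = torch.compile(nn.Linear(4, 4))
print(list(compiled.state_dict().keys()))    # typically ['_orig_mod.weight', '_orig_mod.bias']

try:
    compiled.load_state_dict(ckpt)           # strict=True by default -> missing/unexpected keys
except RuntimeError as err:
    print("resume after compile fails:", err)

# Loading before compiling keeps the key names aligned; compile afterwards as usual.
model = nn.Linear(4, 4)
model.load_state_dict(ckpt)
model = torch.compile(model)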

Made-with: Cursor
王得利 2026-03-19 12:44:53 +00:00
parent 349e74ec7b
commit 8972dab6f5


@@ -123,22 +123,31 @@ if __name__ == "__main__":
     # ========== 5. Define model, data, and optimizer ==========
     model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
-    if args.use_compile == 1:
-        model = torch.compile(model)
-        Logger('torch.compile enabled')
     train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
     train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
     scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
-    # ========== 6. Restore state from checkpoint ==========
+    # ========== 6. Restore state from checkpoint (must run before torch.compile) ==========
     start_epoch, start_step = 0, 0
     if ckp_data:
-        model.load_state_dict(ckp_data['model'])
-        optimizer.load_state_dict(ckp_data['optimizer'])
-        scaler.load_state_dict(ckp_data['scaler'])
+        model.load_state_dict(ckp_data['model'], strict=False)
+        if ckp_data.get('optimizer') is not None:
+            try:
+                optimizer.load_state_dict(ckp_data['optimizer'])
+            except Exception:
+                pass
+        if ckp_data.get('scaler') is not None:
+            try:
+                scaler.load_state_dict(ckp_data['scaler'])
+            except Exception:
+                pass
         start_epoch = ckp_data['epoch']
         start_step = ckp_data.get('step', 0)
+    if args.use_compile == 1:
+        model = torch.compile(model)
+        Logger('torch.compile enabled')
     # ========== 7. Wrap the model with DDP ==========
     if dist.is_initialized():
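
The save side of the checkpoint is outside this hunk; the field names 'model', 'optimizer', 'scaler', 'epoch', and 'step' are only inferred from what the resume code reads, and older checkpoints may omit 'optimizer', 'scaler', or 'step', which is why the restore above is tolerant. A hypothetical save-side sketch consistent with those reads (epoch, step, and checkpoint_path are illustrative names):

# Save the uncompiled module's weights so the checkpoint loads cleanly before torch.compile;
# when the model is compiled, the original module is reachable as model._orig_mod.
raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model
ckp_data = {
    'model': raw_model.state_dict(),
    'optimizer': optimizer.state_dict(),   # older checkpoints may lack this field
    'scaler': scaler.state_dict(),         # older checkpoints may lack this field
    'epoch': epoch,                        # illustrative: current epoch counter
    'step': step,                          # illustrative: current step counter
}
torch.save(ckp_data, checkpoint_path)      # checkpoint_path is an illustrative name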