mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00
Fix SFT checkpoint resume when torch.compile is enabled
Load checkpoint state before torch.compile to avoid key mismatch on resume, and make optimizer/scaler restore tolerant to missing fields in older checkpoints. Made-with: Cursor
This commit is contained in:
parent
349e74ec7b
commit
8972dab6f5
@ -123,22 +123,31 @@ if __name__ == "__main__":
|
||||
|
||||
# ========== 5. 定义模型、数据、优化器 ==========
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
# ========== 6. 从ckp恢复状态(需在 torch.compile 之前) ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
model.load_state_dict(ckp_data['model'], strict=False)
|
||||
if ckp_data.get('optimizer') is not None:
|
||||
try:
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
except Exception:
|
||||
pass
|
||||
if ckp_data.get('scaler') is not None:
|
||||
try:
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
except Exception:
|
||||
pass
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user