diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py
index cc59cc7..5b4cc86 100644
--- a/trainer/train_full_sft.py
+++ b/trainer/train_full_sft.py
@@ -123,22 +123,31 @@ if __name__ == "__main__":
     # ========== 5. 定义模型、数据、优化器 ==========
     model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
 
-    if args.use_compile == 1:
-        model = torch.compile(model)
-        Logger('torch.compile enabled')
     train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
     train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
     scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
 
-    # ========== 6. 从ckp恢复状态 ==========
+    # ========== 6. 从ckp恢复状态(需在 torch.compile 之前) ==========
     start_epoch, start_step = 0, 0
     if ckp_data:
-        model.load_state_dict(ckp_data['model'])
-        optimizer.load_state_dict(ckp_data['optimizer'])
-        scaler.load_state_dict(ckp_data['scaler'])
+        model.load_state_dict(ckp_data['model'], strict=False)
+        if ckp_data.get('optimizer') is not None:
+            try:
+                optimizer.load_state_dict(ckp_data['optimizer'])
+            except Exception as e:
+                Logger(f'optimizer state not restored: {e}')
+        if ckp_data.get('scaler') is not None:
+            try:
+                scaler.load_state_dict(ckp_data['scaler'])
+            except Exception as e:
+                Logger(f'scaler state not restored: {e}')
         start_epoch = ckp_data['epoch']
         start_step = ckp_data.get('step', 0)
+
+    if args.use_compile == 1:
+        model = torch.compile(model)
+        Logger('torch.compile enabled')
 
     # ========== 7. DDP包模型 ==========
     if dist.is_initialized():