From fea69cf338b8d12d0f062c151c4b75ef9328c0ae Mon Sep 17 00:00:00 2001 From: jingyaogong Date: Sun, 18 Jan 2026 16:56:29 +0800 Subject: [PATCH] [fix] data skip --- trainer/train_distillation.py | 2 +- trainer/train_dpo.py | 2 +- trainer/train_full_sft.py | 2 +- trainer/train_grpo.py | 2 +- trainer/train_lora.py | 2 +- trainer/train_ppo.py | 2 +- trainer/train_pretrain.py | 2 +- trainer/train_reason.py | 2 +- trainer/train_spo.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py index 2ec6fbe..5cc6269 100644 --- a/trainer/train_distillation.py +++ b/trainer/train_distillation.py @@ -222,7 +222,7 @@ if __name__ == "__main__": for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() - skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0 + skip = start_step if (epoch == start_epoch and start_step > 0) else 0 batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip) loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) if skip > 0: diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py index 31102cf..f1e567d 100644 --- a/trainer/train_dpo.py +++ b/trainer/train_dpo.py @@ -206,7 +206,7 @@ if __name__ == "__main__": for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() - skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0 + skip = start_step if (epoch == start_epoch and start_step > 0) else 0 batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip) loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) if skip > 0: diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py index 7b64a23..cc59cc7 100644 --- a/trainer/train_full_sft.py +++ b/trainer/train_full_sft.py @@ -149,7 +149,7 @@ if __name__ == "__main__": for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() - skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0 + skip = start_step if (epoch == start_epoch and start_step > 0) else 0 batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip) loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) if skip > 0: diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py index bab4223..5e63779 100755 --- a/trainer/train_grpo.py +++ b/trainer/train_grpo.py @@ -291,7 +291,7 @@ if __name__ == "__main__": for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() - skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0 + skip = start_step if (epoch == start_epoch and start_step > 0) else 0 batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip) loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) if skip > 0: diff --git a/trainer/train_lora.py b/trainer/train_lora.py index 83d54c6..f5b0235 100644 --- a/trainer/train_lora.py +++ b/trainer/train_lora.py @@ -162,7 +162,7 @@ if __name__ == "__main__": for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() - skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0 + skip = start_step if (epoch == start_epoch and start_step > 0) else 0 batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip) loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) if skip > 0: diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py index 47af75a..88c11bc 100644 --- a/trainer/train_ppo.py +++ b/trainer/train_ppo.py @@ -366,7 +366,7 @@ if __name__ == "__main__": for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() - skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0 + skip = start_step if (epoch == start_epoch and start_step > 0) else 0 batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip) loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) if skip > 0: diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py index ea0ca21..a8ad97f 100644 --- a/trainer/train_pretrain.py +++ b/trainer/train_pretrain.py @@ -148,7 +148,7 @@ if __name__ == "__main__": for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() - skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0 + skip = start_step if (epoch == start_epoch and start_step > 0) else 0 batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip) loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) if skip > 0: diff --git a/trainer/train_reason.py b/trainer/train_reason.py index eaae014..a9b2d55 100644 --- a/trainer/train_reason.py +++ b/trainer/train_reason.py @@ -166,7 +166,7 @@ if __name__ == "__main__": for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() - skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0 + skip = start_step if (epoch == start_epoch and start_step > 0) else 0 batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip) loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) if skip > 0: diff --git a/trainer/train_spo.py b/trainer/train_spo.py index b6308d5..219ddba 100755 --- a/trainer/train_spo.py +++ b/trainer/train_spo.py @@ -341,7 +341,7 @@ if __name__ == "__main__": for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() - skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0 + skip = start_step if (epoch == start_epoch and start_step > 0) else 0 batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip) loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) if skip > 0: