From a82526da11ea5c4c222b19cbd7736a774012d75a Mon Sep 17 00:00:00 2001 From: jingyaogong Date: Thu, 23 Oct 2025 20:13:28 +0800 Subject: [PATCH] [feat] shuffle data --- trainer/train_distill_reason.py | 2 +- trainer/train_distillation.py | 2 +- trainer/train_dpo.py | 2 +- trainer/train_full_sft.py | 2 +- trainer/train_grpo.py | 2 +- trainer/train_lora.py | 2 +- trainer/train_ppo.py | 2 +- trainer/train_pretrain.py | 2 +- trainer/train_spo.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/trainer/train_distill_reason.py b/trainer/train_distill_reason.py index 525490e..3bb15a8 100644 --- a/trainer/train_distill_reason.py +++ b/trainer/train_distill_reason.py @@ -199,7 +199,7 @@ if __name__ == "__main__": batch_size=args.batch_size, pin_memory=True, drop_last=False, - shuffle=False, + shuffle=(train_sampler is None), num_workers=args.num_workers, sampler=train_sampler ) diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py index 377e5f3..c97860c 100644 --- a/trainer/train_distillation.py +++ b/trainer/train_distillation.py @@ -248,7 +248,7 @@ if __name__ == "__main__": batch_size=args.batch_size, pin_memory=True, drop_last=False, - shuffle=False, + shuffle=(train_sampler is None), num_workers=args.num_workers, sampler=train_sampler ) diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py index 4b72afa..8c7bb45 100644 --- a/trainer/train_dpo.py +++ b/trainer/train_dpo.py @@ -231,7 +231,7 @@ if __name__ == "__main__": batch_size=args.batch_size, pin_memory=True, drop_last=False, - shuffle=False, + shuffle=(train_sampler is None), num_workers=args.num_workers, sampler=train_sampler ) diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py index 7c601a4..041801f 100644 --- a/trainer/train_full_sft.py +++ b/trainer/train_full_sft.py @@ -185,7 +185,7 @@ if __name__ == "__main__": batch_size=args.batch_size, pin_memory=True, drop_last=False, - shuffle=False, + shuffle=(train_sampler is None), num_workers=args.num_workers, sampler=train_sampler ) diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py index 7c33a94..704800d 100755 --- a/trainer/train_grpo.py +++ b/trainer/train_grpo.py @@ -299,7 +299,7 @@ if __name__ == "__main__": train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len) train_sampler = DistributedSampler(train_ds) if ddp else None train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True, - drop_last=False, shuffle=False, + drop_last=False, shuffle=(train_sampler is None), num_workers=args.num_workers, sampler=train_sampler) optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) diff --git a/trainer/train_lora.py b/trainer/train_lora.py index e2659f4..df9f9ae 100644 --- a/trainer/train_lora.py +++ b/trainer/train_lora.py @@ -196,7 +196,7 @@ if __name__ == "__main__": batch_size=args.batch_size, pin_memory=True, drop_last=False, - shuffle=False, + shuffle=(train_sampler is None), num_workers=args.num_workers, sampler=train_sampler ) diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py index 6ddbb64..7652c22 100644 --- a/trainer/train_ppo.py +++ b/trainer/train_ppo.py @@ -344,7 +344,7 @@ if __name__ == "__main__": train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len)) train_sampler = DistributedSampler(train_ds) if ddp else None train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True, - drop_last=False, shuffle=False, + drop_last=False, shuffle=(train_sampler is None), num_workers=args.num_workers, sampler=train_sampler) # 初始化优化器 diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py index 13b83fc..36a3cd8 100644 --- a/trainer/train_pretrain.py +++ b/trainer/train_pretrain.py @@ -183,7 +183,7 @@ if __name__ == "__main__": batch_size=args.batch_size, pin_memory=True, drop_last=False, - shuffle=False, + shuffle=(train_sampler is None), num_workers=args.num_workers, sampler=train_sampler ) diff --git a/trainer/train_spo.py b/trainer/train_spo.py index 1b81791..e13e741 100755 --- a/trainer/train_spo.py +++ b/trainer/train_spo.py @@ -348,7 +348,7 @@ if __name__ == "__main__": train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len) train_sampler = DistributedSampler(train_ds) if ddp else None train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True, - drop_last=False, shuffle=False, + drop_last=False, shuffle=(train_sampler is None), num_workers=args.num_workers, sampler=train_sampler) optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)