mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[update] shuffle data
This commit is contained in:
@@ -221,13 +221,14 @@ if __name__ == "__main__":
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + start_step + 1, teacher_model, lm_config_student, start_step, wandb, args.alpha, args.temperature)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader) + skip, teacher_model, lm_config_student, start_step, wandb, args.alpha, args.temperature)
|
||||
else:
|
||||
train_epoch(epoch, loader, len(loader), teacher_model, lm_config_student, 0, wandb, args.alpha, args.temperature)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
|
||||
@@ -205,13 +205,14 @@ if __name__ == "__main__":
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, lm_config, start_step, wandb, args.beta)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader) + skip, ref_model, lm_config, start_step, wandb, args.beta)
|
||||
else:
|
||||
train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
|
||||
@@ -148,13 +148,14 @@ if __name__ == "__main__":
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + start_step + 1, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader) + skip, start_step, wandb)
|
||||
else:
|
||||
train_epoch(epoch, loader, len(loader), 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
|
||||
@@ -290,15 +290,14 @@ if __name__ == "__main__":
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
grpo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
|
||||
drop_last=False, shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers, sampler=train_sampler)
|
||||
grpo_train_epoch(epoch, loader, len(loader) + skip, ref_model, reward_model, reward_tokenizer, start_step, wandb)
|
||||
else:
|
||||
grpo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
|
||||
@@ -161,13 +161,14 @@ if __name__ == "__main__":
|
||||
# ========== 9. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + start_step + 1, lora_params, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader) + skip, lora_params, start_step, wandb)
|
||||
else:
|
||||
train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)
|
||||
|
||||
# ========== 10. 清理分布进程 ==========
|
||||
|
||||
@@ -365,15 +365,15 @@ if __name__ == "__main__":
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
ppo_train_epoch(epoch, loader, len(loader) + start_step + 1, old_actor_model, ref_model,
|
||||
ppo_train_epoch(epoch, loader, len(loader) + skip, old_actor_model, ref_model,
|
||||
actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None),
|
||||
sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
else:
|
||||
ppo_train_epoch(epoch, loader, len(loader), old_actor_model, ref_model,
|
||||
actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, 0, wandb)
|
||||
|
||||
|
||||
@@ -147,13 +147,14 @@ if __name__ == "__main__":
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + start_step + 1, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader) + skip, start_step, wandb)
|
||||
else:
|
||||
train_epoch(epoch, loader, len(loader), 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
|
||||
@@ -165,13 +165,14 @@ if __name__ == "__main__":
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + start_step + 1, tokenizer, lm_config, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader) + skip, tokenizer, lm_config, start_step, wandb)
|
||||
else:
|
||||
train_epoch(epoch, loader, len(loader), tokenizer, lm_config, 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
|
||||
@@ -340,15 +340,14 @@ if __name__ == "__main__":
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
spo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, value_tracker, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
|
||||
drop_last=False, shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers, sampler=train_sampler)
|
||||
spo_train_epoch(epoch, loader, len(loader) + skip, ref_model, reward_model, reward_tokenizer, value_tracker, start_step, wandb)
|
||||
else:
|
||||
spo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, value_tracker, 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
|
||||
Reference in New Issue
Block a user