mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
[fix] dist cleanup
This commit is contained in:
parent
9d898576ac
commit
42a4e8c86a
@ -176,3 +176,6 @@ if __name__ == "__main__":
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader), tokenizer, lm_config, 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -226,3 +226,6 @@ if __name__ == "__main__":
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader), teacher_model, lm_config_student, 0, wandb, args.alpha, args.temperature)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -211,3 +211,6 @@ if __name__ == "__main__":
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -164,3 +164,6 @@ if __name__ == "__main__":
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader), 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -294,3 +294,6 @@ if __name__ == "__main__":
|
||||
drop_last=False, shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers, sampler=train_sampler)
|
||||
grpo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -177,3 +177,6 @@ if __name__ == "__main__":
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)
|
||||
|
||||
# ========== 10. 清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -368,3 +368,6 @@ if __name__ == "__main__":
|
||||
sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
ppo_train_epoch(epoch, loader, len(loader), old_actor_model, ref_model,
|
||||
actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -163,3 +163,6 @@ if __name__ == "__main__":
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader), 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -344,3 +344,6 @@ if __name__ == "__main__":
|
||||
drop_last=False, shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers, sampler=train_sampler)
|
||||
spo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, value_tracker, 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
Loading…
Reference in New Issue
Block a user