[fix] dist cleanup

jingyaogong 2026-01-02 22:25:55 +08:00
parent 9d898576ac
commit 42a4e8c86a
9 changed files with 27 additions and 0 deletions
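What the fix amounts to: each of these training scripts may call dist.init_process_group() when launched under DDP, but previously exited without tearing the group down, which can leave NCCL resources alive and trigger a leak warning at interpreter exit. A minimal sketch of the intended lifecycle (the setup helper and the env-var handling are illustrative assumptions, not code from this repo):

```python
import os
import torch
import torch.distributed as dist

def setup_distributed():
    # torchrun exports RANK / LOCAL_RANK / WORLD_SIZE; in a plain
    # single-process run they are absent and initialization is skipped.
    if int(os.environ.get("WORLD_SIZE", "1")) > 1:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

if __name__ == "__main__":
    setup_distributed()
    # ... build the model, wrap it in DDP, run the training epochs ...
    # The guard this commit adds: destroy only if a group was created,
    # so single-GPU runs do not crash on teardown.
    if dist.is_initialized():
        dist.destroy_process_group()
```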

@@ -176,3 +176,6 @@ if __name__ == "__main__":
else: # default: start from scratch
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
train_epoch(epoch, loader, len(loader), tokenizer, lm_config, 0, wandb)
# ========== 9. Clean up distributed process group ==========
if dist.is_initialized(): dist.destroy_process_group()
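The unchanged context lines above show the sampler idiom all of these scripts share: DataLoader rejects shuffle=True combined with a sampler, so shuffling is enabled only when no DistributedSampler is in play. A self-contained sketch of that idiom (the toy dataset and loader arguments are placeholders):

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

train_ds = TensorDataset(torch.arange(100))  # stand-in for the real dataset

# Under DDP each rank reads a disjoint shard; the sampler shuffles
# internally, so DataLoader-level shuffle must then be turned off.
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
loader = DataLoader(train_ds, batch_size=8,
                    shuffle=(train_sampler is None), sampler=train_sampler,
                    num_workers=0, pin_memory=True)
```

One related detail: with a DistributedSampler, calling train_sampler.set_epoch(epoch) before each epoch re-seeds its shuffle; otherwise every epoch sees the same sample order.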

@@ -226,3 +226,6 @@ if __name__ == "__main__":
else: # default: start from scratch
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
train_epoch(epoch, loader, len(loader), teacher_model, lm_config_student, 0, wandb, args.alpha, args.temperature)
# ========== 9. Clean up distributed process group ==========
if dist.is_initialized(): dist.destroy_process_group()

@@ -211,3 +211,6 @@ if __name__ == "__main__":
else: # default: start from scratch
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
# ========== 9. Clean up distributed process group ==========
if dist.is_initialized(): dist.destroy_process_group()

@@ -164,3 +164,6 @@ if __name__ == "__main__":
else: # default: start from scratch
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
train_epoch(epoch, loader, len(loader), 0, wandb)
# ========== 9. Clean up distributed process group ==========
if dist.is_initialized(): dist.destroy_process_group()

@@ -294,3 +294,6 @@ if __name__ == "__main__":
drop_last=False, shuffle=(train_sampler is None),
num_workers=args.num_workers, sampler=train_sampler)
grpo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)
# ========== 9. Clean up distributed process group ==========
if dist.is_initialized(): dist.destroy_process_group()

@@ -177,3 +177,6 @@ if __name__ == "__main__":
else: # default: start from scratch
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)
# ========== 10. Clean up distributed process group ==========
if dist.is_initialized(): dist.destroy_process_group()

@@ -368,3 +368,6 @@ if __name__ == "__main__":
sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
ppo_train_epoch(epoch, loader, len(loader), old_actor_model, ref_model,
actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, 0, wandb)
# ========== 9. Clean up distributed process group ==========
if dist.is_initialized(): dist.destroy_process_group()

@@ -163,3 +163,6 @@ if __name__ == "__main__":
else: # default: start from scratch
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
train_epoch(epoch, loader, len(loader), 0, wandb)
# ========== 9. Clean up distributed process group ==========
if dist.is_initialized(): dist.destroy_process_group()

@@ -344,3 +344,6 @@ if __name__ == "__main__":
drop_last=False, shuffle=(train_sampler is None),
num_workers=args.num_workers, sampler=train_sampler)
spo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, value_tracker, 0, wandb)
# ========== 9. Clean up distributed process group ==========
if dist.is_initialized(): dist.destroy_process_group()
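Usage note: scripts following this pattern are typically launched with torchrun --nproc_per_node=N script.py (the script name is a placeholder). Without torchrun, dist.is_initialized() returns False and the new cleanup line is a safe no-op, which is exactly why every file guards the destroy call instead of calling it unconditionally.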