[fix] DDP exit hang: call dist.barrier() before dist.destroy_process_group()

This commit is contained in:
jingyaogong
2026-06-01 17:50:39 +08:00
parent 4a68da72d5
commit 3f1a7cc25b
8 changed files with 24 additions and 8 deletions
+3 -1
View File
@@ -487,4 +487,6 @@ if __name__ == "__main__":
         else:
             rl_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
-    if dist.is_initialized(): dist.destroy_process_group()
+    if dist.is_initialized():
+        dist.barrier()
+        dist.destroy_process_group()
+3 -1
View File
@@ -243,4 +243,6 @@ if __name__ == "__main__":
             train_epoch(epoch, loader, len(loader), teacher_model, lm_config_student, 0, wandb, args.alpha, args.temperature)
     # ========== 9. 清理分布进程 ==========
-    if dist.is_initialized(): dist.destroy_process_group()
+    if dist.is_initialized():
+        dist.barrier()
+        dist.destroy_process_group()
+3 -1
View File
@@ -223,4 +223,6 @@ if __name__ == "__main__":
             train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
     # ========== 9. 清理分布进程 ==========
-    if dist.is_initialized(): dist.destroy_process_group()
+    if dist.is_initialized():
+        dist.barrier()
+        dist.destroy_process_group()
+3 -1
View File
@@ -168,4 +168,6 @@ if __name__ == "__main__":
             train_epoch(epoch, loader, len(loader), 0, wandb)
     # ========== 9. 清理分布进程 ==========
-    if dist.is_initialized(): dist.destroy_process_group()
+    if dist.is_initialized():
+        dist.barrier()
+        dist.destroy_process_group()
+3 -1
View File
@@ -329,4 +329,6 @@ if __name__ == "__main__":
             grpo_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
     # ========== 9. 清理分布进程 ==========
-    if dist.is_initialized(): dist.destroy_process_group()
+    if dist.is_initialized():
+        dist.barrier()
+        dist.destroy_process_group()
+3 -1
View File
@@ -181,4 +181,6 @@ if __name__ == "__main__":
             train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)
     # ========== 10. 清理分布进程 ==========
-    if dist.is_initialized(): dist.destroy_process_group()
+    if dist.is_initialized():
+        dist.barrier()
+        dist.destroy_process_group()
+3 -1
View File
@@ -432,4 +432,6 @@ if __name__ == "__main__":
             ppo_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, actor_scheduler, critic_scheduler, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
     # ========== 9. 清理分布进程 ==========
-    if dist.is_initialized(): dist.destroy_process_group()
+    if dist.is_initialized():
+        dist.barrier()
+        dist.destroy_process_group()
+3 -1
View File
@@ -167,4 +167,6 @@ if __name__ == "__main__":
             train_epoch(epoch, loader, len(loader), 0, wandb)
     # ========== 9. 清理分布进程 ==========
-    if dist.is_initialized(): dist.destroy_process_group()
+    if dist.is_initialized():
+        dist.barrier()
+        dist.destroy_process_group()