diff --git a/tensorrt_llm/_torch/auto_deploy/distributed/common.py b/tensorrt_llm/_torch/auto_deploy/distributed/common.py
index d680a9fa4c..d9c698b30b 100644
--- a/tensorrt_llm/_torch/auto_deploy/distributed/common.py
+++ b/tensorrt_llm/_torch/auto_deploy/distributed/common.py
@@ -103,6 +103,13 @@ def is_torchelastic():
     return "TORCHELASTIC_RUN_ID" in os.environ
 
 
+def cleanup():
+    """Destroy process group when the program exits."""
+    if dist.is_initialized():
+        ad_logger.info("Destroying process group")
+        dist.destroy_process_group()
+
+
 def initialize(rank: int = 0, world_size: int = 1, port: Optional[int] = None) -> Tuple[int, int]:
     if is_ompi():
         lib = "OMPI"
@@ -136,6 +143,9 @@ def initialize(rank: int = 0, world_size: int = 1, port: Optional[int] = None) -
     # We use nccl backend
     dist.init_process_group("nccl", world_size=world_size, rank=local_rank)
 
+    # Register cleanup function to be called at exit
+    atexit.register(cleanup)
+
     # set a manual seed for reproducibility
     torch.manual_seed(1111)
 
@@ -153,6 +163,9 @@ def init_and_run_process(job, rank, size, port, **kwargs):
             kwargs[q].put(None)
             kwargs[q].close()
         raise e
+    finally:
+        # Make sure to clean up even if an exception occurs
+        cleanup()
 
 
 def _start_multiprocess_job(
@@ -275,3 +288,7 @@ class MultiProcessExecutor:
             q.join_thread()
         self.output_queue.close()
         self.output_queue.join_thread()
+
+        # Make sure all process groups are cleaned up
+        if dist.is_initialized():
+            dist.destroy_process_group()
diff --git a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py
index 3de91df6f3..37d0c35440 100644
--- a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py
+++ b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py
@@ -391,6 +391,13 @@ class DemoLLM(LLM):
             device="cuda",
         )
 
+    def __del__(self):
+        """Ensure proper cleanup of distributed resources."""
+        if hasattr(self, "_executor") and self._executor is not None:
+            self._executor.shutdown()
+        # Call cleanup to ensure process group is properly destroyed
+        dist_ad.cleanup()
+
     @staticmethod
     def _handle_response(request_output: RequestOutput, response: List[CompletionOutput]):
         request_output._done = True
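
For reviewers, the atexit-based teardown that the first file introduces can be exercised in isolation. The sketch below is illustrative and not part of the patch: it assumes a single-node `env://` rendezvous, uses the `gloo` backend so it runs without a GPU (the patch itself uses `nccl`), and replaces the repo-internal `ad_logger` with a plain `print`.

```python
# Minimal standalone sketch of the cleanup pattern added above.
# Assumptions (not from the repo): env:// rendezvous on localhost,
# "gloo" backend, default port 29500.
import atexit
import os

import torch.distributed as dist


def cleanup():
    """Destroy the default process group if it is still initialized."""
    if dist.is_initialized():
        print("Destroying process group")
        dist.destroy_process_group()


def initialize(rank: int = 0, world_size: int = 1, port: int = 29500) -> None:
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", str(port))
    dist.init_process_group("gloo", world_size=world_size, rank=rank)
    # atexit hooks run at interpreter shutdown, so the process group is torn
    # down even on paths that never reach an explicit shutdown call. Because
    # cleanup() guards on is_initialized(), also calling it from a finally
    # block (as the patch does) is safe: the second call becomes a no-op.
    atexit.register(cleanup)


if __name__ == "__main__":
    initialize()
    print(f"rank {dist.get_rank()} of {dist.get_world_size()} initialized")
    # No explicit teardown needed here: the registered hook handles it.
```

The `is_initialized()` guard is what makes the hook safe to combine with the `finally:` path in `init_and_run_process` and the explicit teardown in `MultiProcessExecutor`: whichever runs first wins, and the rest become no-ops instead of raising on a double destroy.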