diff --git a/tensorrt_llm/executor/ray_executor.py b/tensorrt_llm/executor/ray_executor.py
index e03f524bea..0fc4fa2810 100644
--- a/tensorrt_llm/executor/ray_executor.py
+++ b/tensorrt_llm/executor/ray_executor.py
@@ -13,7 +13,7 @@ from ray.util.placement_group import (PlacementGroupSchedulingStrategy,
                                       placement_group)
 
 from tensorrt_llm._ray_utils import unwrap_ray_errors
-from tensorrt_llm._utils import get_free_port, nvtx_range_debug
+from tensorrt_llm._utils import nvtx_range_debug
 from tensorrt_llm.logger import logger
 
 from ..llmapi.utils import logger_debug
@@ -76,7 +76,6 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
         self.world_size = model_world_size
         self.tp_size = tp_size
         self.master_address = ray.util.get_node_ip_address()
-        self.master_port = get_free_port()
 
         self.worker_kwargs = dict(
             **worker_kwargs,
@@ -126,7 +125,6 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
             runtime_env["env_vars"].update({
                 "TLLM_DISABLE_MPI": "1",
                 "MASTER_ADDR": self.master_address,  # head-IP for NCCL/Gloo
-                "MASTER_PORT": str(self.master_port)
             })
 
             placement_groups, self.bundle_indices = self._get_placement_group(
@@ -156,6 +154,13 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
             ray.get(self._get_worker_ready_futures())
         except ray.exceptions.ActorDiedError as e:
             raise RuntimeError("RayGPUWorker died during initialization") from e
+        port = self.call_all_ray_workers("setup_tcp_store",
+                                         leader_only=True,
+                                         async_call=False)[0]
+        self.call_all_ray_workers("setup_distributed_env_and_worker",
+                                  leader_only=False,
+                                  async_call=False,
+                                  port=port)
 
     async def init_workers_async(self):
         self.create_workers(RayGPUWorker, self.worker_kwargs)
@@ -163,6 +168,13 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
             await asyncio.gather(*self._get_worker_ready_futures())
         except ray.exceptions.ActorDiedError as e:
             raise RuntimeError("RayGPUWorker died during initialization") from e
+        port = (await asyncio.gather(*self.call_all_ray_workers(
+            "setup_tcp_store", leader_only=True, async_call=True)))[0]
+        await asyncio.gather(
+            *self.call_all_ray_workers("setup_distributed_env_and_worker",
+                                       leader_only=False,
+                                       async_call=True,
+                                       port=port))
 
     @unwrap_ray_errors()
     def call_all_ray_workers(self, func: str, leader_only: bool,
diff --git a/tensorrt_llm/executor/ray_gpu_worker.py b/tensorrt_llm/executor/ray_gpu_worker.py
index fca5386cb5..864d23d3af 100644
--- a/tensorrt_llm/executor/ray_gpu_worker.py
+++ b/tensorrt_llm/executor/ray_gpu_worker.py
@@ -1,6 +1,7 @@
 import gc
 import importlib
 import os
+from functools import wraps
 from pathlib import Path
 from queue import Queue
 from typing import Any, List, Optional, Type, Union
@@ -43,7 +44,8 @@ class RayWorkerWrapper:
 
     def __init__(self, worker_cls, worker_kwargs, world_size, rank):
         self.master_address = os.environ["MASTER_ADDR"]
-        self.master_port = os.environ["MASTER_PORT"]
+        self.world_size = world_size
+        self.rank = rank
         # Ray can't pickle TensorRT logger
         global logger
         from tensorrt_llm.logger import logger
@@ -55,39 +57,83 @@ class RayWorkerWrapper:
 
         # Physical gpu id
         self.gpu = int(ray.get_gpu_ids()[0])
-        local_gpu = self.physical_to_local_id(self.gpu)
+        self.local_gpu = self.physical_to_local_id(self.gpu)
 
-        torch.distributed.init_process_group(
-            backend="cuda:nccl,cpu:gloo",
-            init_method=f"tcp://{self.master_address}:{self.master_port}",
-            world_size=world_size,
-            rank=rank)
+        torch.cuda.set_device(self.local_gpu)
+        self.worker_cls = RayWorkerWrapper._inject_worker_extension(
+            worker_cls, worker_kwargs.pop("ray_worker_extension_cls", None))
+        self.worker_kwargs = worker_kwargs
+
+    def _create_tcp_store(self,
+                          port: Optional[int] = None
+                          ) -> torch.distributed.TCPStore:
+        # port=0 means let the OS pick an available port (only valid for master)
+        # For non-master, port must be specified to connect to master's port
+        actual_port = port if port is not None else 0
+        return torch.distributed.TCPStore(host_name=self.master_address,
+                                          port=actual_port,
+                                          world_size=self.world_size,
+                                          is_master=(self.rank == 0),
+                                          wait_for_workers=False)
+
+    def setup_tcp_store(self):
+        if self.rank != 0:
+            raise RuntimeError("Only the master worker can setup TCP store")
+        self.store = self._create_tcp_store()
+        return self.store.port
+
+    def setup_distributed_env_and_worker(self, port: int):
+        if self.rank != 0:
+            self.store = self._create_tcp_store(port)
+
+        torch.distributed.init_process_group(backend="cuda:nccl,cpu:gloo",
+                                             store=self.store,
+                                             world_size=self.world_size,
+                                             rank=self.rank)
 
         logger.info(
-            f"[Rank {rank}] Finished PG init. Global GPU ID: {self.gpu}, local GPU ID: {local_gpu}"
+            f"[Rank {self.rank}] Finished PG init. Global GPU ID: {self.gpu}, local GPU ID: {self.local_gpu}"
         )
 
-        torch.cuda.set_device(local_gpu)
+        self.worker = self.worker_cls(device_id=self.local_gpu,
+                                      **self.worker_kwargs)
+        self._has_setup_distributed_env_and_worker = True
 
-        worker_cls = RayWorkerWrapper._inject_worker_extension(
-            worker_cls, worker_kwargs.pop("ray_worker_extension_cls", None))
-        self.worker = worker_cls(device_id=local_gpu, **worker_kwargs)
+    @property
+    def has_setup_distributed_env_and_worker(self) -> bool:
+        return getattr(self, '_has_setup_distributed_env_and_worker', False)
 
+    def ensure_distributed_setup(func):
+
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if not self.has_setup_distributed_env_and_worker:
+                raise RuntimeError(
+                    "Have not setup distributed environment and worker yet")
+            return func(self, *args, **kwargs)
+
+        return wrapper
+
+    @ensure_distributed_setup
     def submit(self, request: GenerationRequest) -> GenerationResult:
         return self.worker.submit(request)
 
+    @ensure_distributed_setup
     def enqueue_request(self,
                         request: GenerationRequest,
                         result_wait_queue: Queue | None = None) -> int:
         return self.worker.enqueue_request(request, result_wait_queue)
 
+    @ensure_distributed_setup
     def abort_request(self, request_id: int) -> None:
         self.worker.abort_request(request_id)
 
+    @ensure_distributed_setup
     def report_device_id(self) -> str:
         local_id = self.physical_to_local_id(self.gpu)
         return get_device_uuid(local_id)
 
+    @ensure_distributed_setup
     def call_worker_method(self, method_name: str, *args, **kwargs):
         """Generic method to call any method on the underlying worker."""
         if hasattr(self.worker, method_name):
@@ -103,7 +149,8 @@ class RayWorkerWrapper:
             f"The RayGPUWorker has no method called '{method_name}'.")
 
     def shutdown(self):
-        return self.worker.shutdown()
+        if hasattr(self, 'worker'):
+            self.worker.shutdown()
 
     def __repr__(self) -> str:
         """Customizes the actor's prefix in the Ray logs.
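
For reviewers, here is a minimal standalone sketch of the TCPStore rendezvous this patch switches to (replacing the `MASTER_PORT` / `init_method="tcp://..."` scheme): the leader binds port 0 so the OS picks a free port, publishes `store.port` out of band, the other ranks connect to that port, and every rank hands the same store to `init_process_group`. The sketch is illustrative only and not part of the PR: it uses local CPU processes, the `gloo` backend, `127.0.0.1`, and a `multiprocessing.Queue` standing in for the `call_all_ray_workers("setup_tcp_store", ...)` RPC; the helper name `run_rank` is hypothetical.

```python
# Sketch only: local CPU ranks + gloo stand in for Ray actors + NCCL.
import multiprocessing as mp

import torch
import torch.distributed as dist


def run_rank(rank, world_size, port_queue):
    if rank == 0:
        # Leader: bind port 0 so the OS picks a free port, then publish it
        # (the patch does this in setup_tcp_store and returns store.port via Ray).
        store = dist.TCPStore(host_name="127.0.0.1",
                              port=0,
                              world_size=world_size,
                              is_master=True,
                              wait_for_workers=False)
        for _ in range(world_size - 1):
            port_queue.put(store.port)
    else:
        # Followers: connect to the leader's store on the published port
        # (the patch does this in setup_distributed_env_and_worker(port)).
        store = dist.TCPStore(host_name="127.0.0.1",
                              port=port_queue.get(),
                              world_size=world_size,
                              is_master=False,
                              wait_for_workers=False)

    # Every rank passes the same store to init_process_group; no MASTER_PORT
    # environment variable or tcp:// init_method is needed anymore.
    dist.init_process_group(backend="gloo",
                            store=store,
                            world_size=world_size,
                            rank=rank)
    t = torch.ones(1)
    dist.all_reduce(t)  # sanity check: every rank ends up with world_size
    assert int(t.item()) == world_size
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    ctx = mp.get_context("spawn")
    port_queue = ctx.Queue()
    procs = [
        ctx.Process(target=run_rank, args=(r, world_size, port_queue))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```

Since the leader binds an ephemeral port on the node it actually runs on, this avoids the race in the old scheme where `get_free_port()` was sampled on the driver and could be taken by the time the workers called `init_process_group`.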