[TRTLLM-9784][fix] Resolve port conflicts (#9780)

Signed-off-by: Shuyi Xiong <219646547+shuyixiong@users.noreply.github.com>
Authored by shuyixiong on 2025-12-13 14:10:01 +08:00; committed by GitHub
parent e49c70f6df
commit 7fc720a397
2 changed files with 75 additions and 16 deletions
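
The change drops the driver-side get_free_port() / MASTER_PORT handoff, which can hand out a port that is already taken by the time the workers bind to it, and replaces it with a worker-side TCPStore rendezvous: rank 0 binds a store to port 0 so the OS assigns a free port at bind time, reports store.port back to the executor, and every rank then joins that store before init_process_group. A minimal, standalone sketch of the mechanism follows; it runs a single rank with the gloo backend, and the variable names are illustrative rather than taken from the patch.

import torch.distributed as dist

# Rank 0 binds the rendezvous store to port 0, so the OS picks a port that is
# free at bind time instead of one that was merely observed to be free earlier.
world_size, rank, addr = 1, 0, "127.0.0.1"
store = dist.TCPStore(host_name=addr,
                      port=0,
                      world_size=world_size,
                      is_master=True,
                      wait_for_workers=False)
print("OS-assigned rendezvous port:", store.port)  # shared with the other ranks

# Handing the store to init_process_group replaces
# init_method=f"tcp://{addr}:{port}", so no rank ever has to re-bind the port.
dist.init_process_group(backend="gloo",
                        store=store,
                        world_size=world_size,
                        rank=rank)
dist.destroy_process_group()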


@@ -13,7 +13,7 @@ from ray.util.placement_group import (PlacementGroupSchedulingStrategy,
                                        placement_group)
 from tensorrt_llm._ray_utils import unwrap_ray_errors
-from tensorrt_llm._utils import get_free_port, nvtx_range_debug
+from tensorrt_llm._utils import nvtx_range_debug
 from tensorrt_llm.logger import logger
 from ..llmapi.utils import logger_debug
@@ -76,7 +76,6 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
         self.world_size = model_world_size
         self.tp_size = tp_size
         self.master_address = ray.util.get_node_ip_address()
-        self.master_port = get_free_port()
         self.worker_kwargs = dict(
             **worker_kwargs,
@@ -126,7 +125,6 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
         runtime_env["env_vars"].update({
             "TLLM_DISABLE_MPI": "1",
             "MASTER_ADDR": self.master_address,  # head-IP for NCCL/Gloo
-            "MASTER_PORT": str(self.master_port)
         })
         placement_groups, self.bundle_indices = self._get_placement_group(
@@ -156,6 +154,13 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
             ray.get(self._get_worker_ready_futures())
         except ray.exceptions.ActorDiedError as e:
             raise RuntimeError("RayGPUWorker died during initialization") from e
+        port = self.call_all_ray_workers("setup_tcp_store",
+                                         leader_only=True,
+                                         async_call=False)[0]
+        self.call_all_ray_workers("setup_distributed_env_and_worker",
+                                  leader_only=False,
+                                  async_call=False,
+                                  port=port)

     async def init_workers_async(self):
         self.create_workers(RayGPUWorker, self.worker_kwargs)
@@ -163,6 +168,13 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
             await asyncio.gather(*self._get_worker_ready_futures())
         except ray.exceptions.ActorDiedError as e:
             raise RuntimeError("RayGPUWorker died during initialization") from e
+        port = (await asyncio.gather(*self.call_all_ray_workers(
+            "setup_tcp_store", leader_only=True, async_call=True)))[0]
+        await asyncio.gather(
+            *self.call_all_ray_workers("setup_distributed_env_and_worker",
+                                       leader_only=False,
+                                       async_call=True,
+                                       port=port))

     @unwrap_ray_errors()
     def call_all_ray_workers(self, func: str, leader_only: bool,
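
The executor now performs the handshake in two passes: a leader-only call to setup_tcp_store returns the OS-assigned port, and a second call to all ranks passes that port so each of them can join the store. Below is a hedged sketch of the same flow with plain Ray actors; the two method names mirror the patch, while the actor class, its constructor, and the gloo backend are assumptions made for the example. The worker-side implementation of these methods appears in the second file of the diff.

import ray
import torch.distributed as dist

@ray.remote
class DemoWorker:  # illustrative stand-in for the patched RayWorkerWrapper
    def __init__(self, rank: int, world_size: int, master_addr: str):
        self.rank, self.world_size, self.master_addr = rank, world_size, master_addr
        self.store = None

    def setup_tcp_store(self) -> int:
        # Leader only: bind to port 0 and report the port the OS chose.
        assert self.rank == 0
        self.store = dist.TCPStore(host_name=self.master_addr,
                                   port=0,
                                   world_size=self.world_size,
                                   is_master=True,
                                   wait_for_workers=False)
        return self.store.port

    def setup_distributed_env_and_worker(self, port: int) -> None:
        # Non-leaders connect to the already-bound store at the reported port.
        if self.rank != 0:
            self.store = dist.TCPStore(host_name=self.master_addr,
                                       port=port,
                                       world_size=self.world_size,
                                       is_master=False,
                                       wait_for_workers=False)
        dist.init_process_group(backend="gloo",
                                store=self.store,
                                world_size=self.world_size,
                                rank=self.rank)

world_size = 2
workers = [DemoWorker.remote(i, world_size, "127.0.0.1") for i in range(world_size)]
port = ray.get(workers[0].setup_tcp_store.remote())                  # leader-only pass
ray.get([w.setup_distributed_env_and_worker.remote(port) for w in workers])  # all ranks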


@@ -1,6 +1,7 @@
 import gc
 import importlib
 import os
+from functools import wraps
 from pathlib import Path
 from queue import Queue
 from typing import Any, List, Optional, Type, Union
@@ -43,7 +44,8 @@ class RayWorkerWrapper:
     def __init__(self, worker_cls, worker_kwargs, world_size, rank):
         self.master_address = os.environ["MASTER_ADDR"]
-        self.master_port = os.environ["MASTER_PORT"]
+        self.world_size = world_size
+        self.rank = rank
         # Ray can't pickle TensorRT logger
         global logger
         from tensorrt_llm.logger import logger
@@ -55,39 +57,83 @@ class RayWorkerWrapper:
         # Physical gpu id
         self.gpu = int(ray.get_gpu_ids()[0])
-        local_gpu = self.physical_to_local_id(self.gpu)
+        self.local_gpu = self.physical_to_local_id(self.gpu)
-        torch.distributed.init_process_group(
-            backend="cuda:nccl,cpu:gloo",
-            init_method=f"tcp://{self.master_address}:{self.master_port}",
-            world_size=world_size,
-            rank=rank)
+        torch.cuda.set_device(self.local_gpu)
+        self.worker_cls = RayWorkerWrapper._inject_worker_extension(
+            worker_cls, worker_kwargs.pop("ray_worker_extension_cls", None))
+        self.worker_kwargs = worker_kwargs

+    def _create_tcp_store(self,
+                          port: Optional[int] = None
+                          ) -> torch.distributed.TCPStore:
+        # port=0 means let the OS pick an available port (only valid for master)
+        # For non-master, port must be specified to connect to master's port
+        actual_port = port if port is not None else 0
+        return torch.distributed.TCPStore(host_name=self.master_address,
+                                          port=actual_port,
+                                          world_size=self.world_size,
+                                          is_master=(self.rank == 0),
+                                          wait_for_workers=False)

+    def setup_tcp_store(self):
+        if self.rank != 0:
+            raise RuntimeError("Only the master worker can setup TCP store")
+        self.store = self._create_tcp_store()
+        return self.store.port

+    def setup_distributed_env_and_worker(self, port: int):
+        if self.rank != 0:
+            self.store = self._create_tcp_store(port)
+        torch.distributed.init_process_group(backend="cuda:nccl,cpu:gloo",
+                                             store=self.store,
+                                             world_size=self.world_size,
+                                             rank=self.rank)
         logger.info(
-            f"[Rank {rank}] Finished PG init. Global GPU ID: {self.gpu}, local GPU ID: {local_gpu}"
+            f"[Rank {self.rank}] Finished PG init. Global GPU ID: {self.gpu}, local GPU ID: {self.local_gpu}"
         )
-        torch.cuda.set_device(local_gpu)
+        self.worker = self.worker_cls(device_id=self.local_gpu,
+                                      **self.worker_kwargs)
+        self._has_setup_distributed_env_and_worker = True
-        worker_cls = RayWorkerWrapper._inject_worker_extension(
-            worker_cls, worker_kwargs.pop("ray_worker_extension_cls", None))
-        self.worker = worker_cls(device_id=local_gpu, **worker_kwargs)

+    @property
+    def has_setup_distributed_env_and_worker(self) -> bool:
+        return getattr(self, '_has_setup_distributed_env_and_worker', False)

+    def ensure_distributed_setup(func):

+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if not self.has_setup_distributed_env_and_worker:
+                raise RuntimeError(
+                    "Have not setup distributed environment and worker yet")
+            return func(self, *args, **kwargs)

+        return wrapper

+    @ensure_distributed_setup
     def submit(self, request: GenerationRequest) -> GenerationResult:
         return self.worker.submit(request)

+    @ensure_distributed_setup
     def enqueue_request(self,
                         request: GenerationRequest,
                         result_wait_queue: Queue | None = None) -> int:
         return self.worker.enqueue_request(request, result_wait_queue)

+    @ensure_distributed_setup
     def abort_request(self, request_id: int) -> None:
         self.worker.abort_request(request_id)

+    @ensure_distributed_setup
     def report_device_id(self) -> str:
         local_id = self.physical_to_local_id(self.gpu)
         return get_device_uuid(local_id)

+    @ensure_distributed_setup
     def call_worker_method(self, method_name: str, *args, **kwargs):
         """Generic method to call any method on the underlying worker."""
         if hasattr(self.worker, method_name):
@@ -103,7 +149,8 @@ class RayWorkerWrapper:
                 f"The RayGPUWorker has no method called '{method_name}'.")

     def shutdown(self):
-        return self.worker.shutdown()
+        if hasattr(self, 'worker'):
+            self.worker.shutdown()

     def __repr__(self) -> str:
         """Customizes the actor's prefix in the Ray logs.