mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
453 lines · 18 KiB · Python
import asyncio
import os
from typing import Any, Dict, List, Optional, Tuple

try:
    import ray
except ModuleNotFoundError as e:
    e.msg = """Cannot import Ray. Please install the 'ray' package to use the ray orchestrator."""
    raise

from ray.util.placement_group import (PlacementGroupSchedulingStrategy,
                                      get_current_placement_group,
                                      placement_group)

from tensorrt_llm._ray_utils import unwrap_ray_errors
from tensorrt_llm._utils import nvtx_range_debug
from tensorrt_llm.logger import logger

from ..llmapi.utils import logger_debug
from .executor import GenerationExecutor
from .postproc_worker import PostprocWorkerConfig
from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
from .request import GenerationRequest
from .result import GenerationResult
from .rpc_proxy_mixin import RpcExecutorMixin
from .utils import has_event_loop

__all__ = [
    "RayExecutor",
]


class RayExecutor(RpcExecutorMixin, GenerationExecutor):
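    """Ray-based GenerationExecutor.

    Spawns one RayGPUWorker actor per rank (optionally constrained by Ray
    placement groups), wires the actors to this process through the RPC client
    provided by RpcExecutorMixin, and forwards generation requests to them.

    Example (illustrative sketch only; in practice this class is typically
    created through the executor factory rather than constructed directly)::

        executor = RayExecutor(worker_kwargs,  # engine/worker arguments from the caller
                               model_world_size=2,
                               postproc_worker_config=postproc_worker_config,
                               is_llm_executor=True,
                               tp_size=2)
        result = executor.submit(request)      # request: GenerationRequest
        executor.shutdown()
    """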

    def __init__(self,
                 worker_kwargs: Dict,
                 model_world_size: int,
                 postproc_worker_config: PostprocWorkerConfig,
                 is_llm_executor: bool,
                 tp_size=1):
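        """Connect to (or start) a Ray cluster and prepare the executor.

        Args:
            worker_kwargs: Keyword arguments forwarded to each RayGPUWorker.
            model_world_size: Total number of worker ranks to create.
            postproc_worker_config: Configuration for postprocessing workers.
            is_llm_executor: Whether this executor backs the top-level LLM API.
            tp_size: Tensor-parallel size used when building placement groups.
        """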
        os.environ['RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES'] = '1'
        os.environ["RAY_DEDUP_LOGS"] = "0"  # for debug

        super().__init__(model_world_size, postproc_worker_config,
                         is_llm_executor)

        self.has_start_local_cluster = False
        runtime_env = {
            "env_vars": {
                "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"
            }
        }

        ray_init_args = {
            "include_dashboard": False,
            "namespace": "trtllm",
            "ignore_reinit_error": True,
            "runtime_env": runtime_env
        }

        try:
            if os.environ.get("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0") != "1":
                try:
                    ray.init(address="auto", **ray_init_args)
                    logger.info("Attached to an existing Ray cluster.")
                except ConnectionError:
                    logger.info("Ray cluster not found, starting a new one.")

                if not ray.is_initialized():
                    ray.init(**ray_init_args)
                    self.has_start_local_cluster = True
            else:
                ray.init(address="local", **ray_init_args)
                self.has_start_local_cluster = True

            self.world_size = model_world_size
            self.tp_size = tp_size
            self.master_address = ray.util.get_node_ip_address()

            self.worker_kwargs = dict(
                **worker_kwargs,
                postproc_worker_config=postproc_worker_config,
                is_llm_executor=is_llm_executor)

            self.init_rpc_executor()
            # Inject the generated HMAC key into worker_kwargs for workers
            self.worker_kwargs['hmac_key'] = self.hmac_key
            self.worker_kwargs['rpc_addr'] = self.rpc_addr

            placement_config = getattr(self.worker_kwargs['llm_args'],
                                       'ray_placement_config', None)
            defer_workers_init = placement_config.defer_workers_init if placement_config else False

            if defer_workers_init:
                self.workers = []  # Placeholder, will be initialized in setup_async
                # DO NOT start the mainloop until after setup_engine_remote_async is called.
                self._mainloop_started = False
            else:
                if not has_event_loop():
                    self.init_workers_sync()
                    self.setup_engine_remote()
                    self.setup_mainloop(
                        tasks=[self._fetch_responses_loop_async],
                        thread_name="ray_executor_main_loop")

        except Exception as e:
            self.shutdown()
            logger.error(f"Failed to initialize RayExecutor: {e}")
            raise e

    def create_workers(self, worker_cls, worker_kwargs):
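        """Create one RayWorkerWrapper actor per rank, scheduled into placement groups."""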
        llm_args = worker_kwargs.get("llm_args")
        placement_config = getattr(llm_args, 'ray_placement_config',
                                   None) if llm_args else None

        # When set to a fraction, this allows Ray to schedule multiple actors
        # on a single GPU for colocation use cases.
        num_gpus = float(os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0"))
        if placement_config and placement_config.per_worker_gpu_share is not None:
            num_gpus = placement_config.per_worker_gpu_share
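        # Illustrative example (not part of the original code): with
        # TRTLLM_RAY_PER_WORKER_GPUS=0.5 (or per_worker_gpu_share=0.5), two worker
        # actors can be co-scheduled on the same physical GPU.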

        logger.debug(f"{num_gpus=} for each worker.")

        runtime_env = ray.runtime_env.RuntimeEnv()
        runtime_env["env_vars"] = os.environ.copy()
        runtime_env["env_vars"].update({
            "TLLM_DISABLE_MPI": "1",
            "MASTER_ADDR": self.master_address,  # head-node IP for NCCL/Gloo
        })

        placement_groups, self.bundle_indices = self._get_placement_group(
            tp_size=self.tp_size, worker_kwargs=worker_kwargs)

        if isinstance(placement_groups, list):
            self.placement_group = None
        else:
            self.placement_group = placement_groups

        self.workers = []
        for rank in range(self.world_size):
            pg = placement_groups[rank] if isinstance(
                placement_groups, list) else placement_groups
            worker = RayWorkerWrapper.options(
                num_gpus=num_gpus,
                runtime_env=runtime_env,
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=pg,
                    placement_group_bundle_index=self.bundle_indices[rank],
                )).remote(worker_cls, worker_kwargs, self.world_size, rank)
            self.workers.append(worker)

    def init_workers_sync(self):
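        """Create the workers and run distributed setup, blocking until they are ready."""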
        self.create_workers(RayGPUWorker, self.worker_kwargs)
        try:
            ray.get(self._get_worker_ready_futures())
        except ray.exceptions.ActorDiedError as e:
            raise RuntimeError("RayGPUWorker died during initialization") from e
        port = self.call_all_ray_workers("setup_tcp_store",
                                         leader_only=True,
                                         async_call=False)[0]
        self.call_all_ray_workers("setup_distributed_env_and_worker",
                                  leader_only=False,
                                  async_call=False,
                                  port=port)

    async def init_workers_async(self):
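        """Async variant of init_workers_sync; awaits worker readiness and distributed setup."""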
        self.create_workers(RayGPUWorker, self.worker_kwargs)
        try:
            await asyncio.gather(*self._get_worker_ready_futures())
        except ray.exceptions.ActorDiedError as e:
            raise RuntimeError("RayGPUWorker died during initialization") from e
        port = (await asyncio.gather(*self.call_all_ray_workers(
            "setup_tcp_store", leader_only=True, async_call=True)))[0]
        await asyncio.gather(
            *self.call_all_ray_workers("setup_distributed_env_and_worker",
                                       leader_only=False,
                                       async_call=True,
                                       port=port))

    @unwrap_ray_errors()
    def call_all_ray_workers(self, func: str, leader_only: bool,
                             async_call: bool, *args, **kwargs):
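        """Call `func` on the leader worker only or on all workers.

        Returns Ray object refs when async_call is True, otherwise the resolved results.
        """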
        workers = (self.workers[0], ) if leader_only else self.workers
        if async_call:
            return [
                getattr(worker, func).remote(*args, **kwargs)
                for worker in workers
            ]
        else:
            return ray.get([
                getattr(worker, func).remote(*args, **kwargs)
                for worker in workers
            ])

    @unwrap_ray_errors()
    def collective_rpc(self,
                       method: str,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
                       non_block: bool = False,
                       unique_reply_rank: Optional[int] = None) -> list[Any]:
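        """Invoke `method` on every worker, or only on `unique_reply_rank` if given.

        Returns Ray object refs when non_block is True, otherwise the gathered results.
        """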
        workers = (self.workers[unique_reply_rank],
                   ) if unique_reply_rank is not None else self.workers
        kwargs = kwargs or {}

        refs = []
        for w in workers:
            try:
                refs.append(getattr(w, method).remote(*args, **kwargs))
            except AttributeError:
                # Here worker is the RayWorkerWrapper.
                # For extended worker methods, we need to use call_worker_method since
                # Ray actor doesn't work with __getattr__ delegation.
                refs.append(w.call_worker_method.remote(method, *args,
                                                        **kwargs))
        return refs if non_block else ray.get(refs)

    @unwrap_ray_errors()
    async def collective_rpc_async(
            self,
            method: str,
            args: tuple = (),
            kwargs: Optional[dict] = None,
            unique_reply_rank: Optional[int] = None) -> list[Any]:
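        """Async variant of collective_rpc; awaits the remote calls and returns their results."""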
        refs = self.collective_rpc(method,
                                   args,
                                   kwargs,
                                   non_block=True,
                                   unique_reply_rank=unique_reply_rank)
        return await asyncio.gather(*refs)

    def submit(self, request: "GenerationRequest") -> "GenerationResult":
        """
        Low-level API to the executor. Returns a "future" GenerationResult
        that can be waited on. Forwards the request to the workers through RPC.
        """
        request.set_id(self._get_next_client_id())
        logprob_params = self._get_logprob_params(request)

        with nvtx_range_debug("rpc_submit"):
            self.rpc_client.submit(request).remote(need_response=False)

        result = GenerationResult(
            request,
            background_error_handler=self._handle_background_error,
            executor=self,
            disaggregated_params=request.disaggregated_params,
            logprob_params=logprob_params)
        self._results[request.id] = result

        return result

    def start(self):
        pass

    def setup_engine_remote(self):
        return self.collective_rpc("setup_engine", non_block=False)

    async def setup_engine_remote_async(self):
        """Async version of setup_engine_remote for use after async worker initialization."""
        if not self.workers or len(self.workers) == 0:
            raise RuntimeError(
                "Workers must be initialized before calling setup_engine_remote_async"
            )

        # Set up the engine on all workers.
        result = await self.collective_rpc_async("setup_engine")
        logger.info("setup_engine_remote_async finished")

        # Now that the engine is set up, start the mainloop for fetching responses.
        if hasattr(self, '_mainloop_started') and not self._mainloop_started:
            logger.info("Starting mainloop after engine setup")
            self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
                                thread_name="ray_executor_main_loop")
            self._mainloop_started = True

        return result

    def report_device_ids(self) -> list[str]:
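        """Return the sorted device (GPU) IDs reported by all workers."""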
        gpu_ids = self.call_all_ray_workers("report_device_id",
                                            leader_only=False,
                                            async_call=False)
        return sorted(gpu_ids)

    def abort_request(self, request_id: int) -> None:
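        """Ask the leader worker to abort the request with the given ID."""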
        self.call_all_ray_workers("abort_request",
                                  leader_only=True,
                                  async_call=False,
                                  request_id=request_id)

    def shutdown(self):
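        """Shut down the main loop, the workers, the RPC client, the placement
        group and, if this executor started it, the local Ray cluster."""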
        if hasattr(self, '_shutdown_event') and self._shutdown_event.is_set():
            return
        if hasattr(self, '_shutdown_event'):
            self._shutdown_event.set()

        logger_debug("Shutting down RayExecutor", color="yellow")

        if hasattr(self, 'main_loop') and self.main_loop and hasattr(
                self, 'main_loop_task_obj') and self.main_loop_task_obj:
            logger_debug("Cancelling main loop task.", color="yellow")
            try:
                self.main_loop.call_soon_threadsafe(
                    self.main_loop_task_obj.cancel)
            except Exception as e:
                logger_debug(f"Error cancelling main loop task: {e}",
                             color="yellow")

        if hasattr(self, 'main_loop_thread'):
            self.main_loop_thread.join()

        # Then, shut down the workers.
        if hasattr(self, 'workers') and self.workers is not None:
            try:
                shutdown_refs = [
                    worker.shutdown.remote() for worker in self.workers
                ]
                # Add a timeout to prevent indefinite hanging.
                ray.get(shutdown_refs, timeout=30.0)
            except ray.exceptions.GetTimeoutError:
                logger.warning(
                    "Timeout waiting for workers to shut down after 30 seconds")
            except Exception as e:
                logger.warning(f"Error shutting down: {e}")

        if hasattr(self, 'rpc_client') and self.rpc_client is not None:
            try:
                self.rpc_client.close()
            except Exception as e:
                logger_debug(f"Suppressed error during RPC client close: {e}")

        self.workers = None
        if hasattr(self,
                   "placement_group") and self.placement_group is not None:
            # Only remove the placement group if Ray is still initialized,
            # to avoid triggering auto_init_ray() during program exit.
            if ray.is_initialized():
                ray.util.remove_placement_group(self.placement_group)
            self.placement_group = None
            self.bundle_indices = None

        if self.has_start_local_cluster and ray.is_initialized():
            logger.debug("Shutting down Ray cluster")
            ray.shutdown()

    def _get_worker_ready_futures(self):
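        """Return readiness futures (__ray_ready__) for all worker actors."""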
        return [worker.__ray_ready__.remote() for worker in self.workers]

    def _get_placement_group(
            self,
            tp_size: int,
            worker_kwargs: Optional[Dict] = None) -> Tuple[Any, List[int]]:
        """
        Either use the existing placement group from the driver script (e.g., for RL
        framework integration), or create a default PACK placement group where each
        bundle has tp_size GPUs.

        - When tp_size ≤ GPUs per node, keep one TP group per node.
        - When tp_size > GPUs per node, allow a TP group to span nodes.
        - Rank 0 must be placed on the driver node.

        Returns:
            Tuple of (placement_group(s), bundle_indices)
            - placement_group(s) can be a single PlacementGroup or a List[PlacementGroup]
            - bundle_indices is always a List[int]
        """
        llm_args = worker_kwargs.get("llm_args") if worker_kwargs else None

        placement_config = getattr(llm_args, 'ray_placement_config',
                                   None) if llm_args else None
        if placement_config and placement_config.placement_groups is not None:
            total_workers = sum(
                len(indices)
                for indices in placement_config.placement_bundle_indices)
            if total_workers != self.world_size:
                raise ValueError(
                    f"Total bundle indices ({total_workers}) must equal world_size ({self.world_size})"
                )

            logger.info(
                f"Creating {self.world_size} workers with external placement groups"
            )

            flat_pgs = []
            flat_indices = []
            for pg, indices in zip(placement_config.placement_groups,
                                   placement_config.placement_bundle_indices):
                for idx in indices:
                    flat_pgs.append(pg)
                    flat_indices.append(idx)

            return flat_pgs, flat_indices

        bundle_indices = os.getenv("TRTLLM_RAY_BUNDLE_INDICES", None)
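        # Illustrative example (not part of the original code): launching with
        # world_size=4 under an existing placement group and
        # TRTLLM_RAY_BUNDLE_INDICES="0,1,2,3" pins rank i to bundle i of that group.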

        if bundle_indices:
            pg = get_current_placement_group()
            if pg is not None:
                bundle_indices = list(map(int, bundle_indices.split(",")))
                assert len(bundle_indices) == self.world_size, (
                    f"Need {self.world_size} bundle indices for world_size, got {bundle_indices=}"
                )
                assert len(set(bundle_indices)) == len(bundle_indices), \
                    f"TRTLLM_RAY_BUNDLE_INDICES cannot have duplicate values, but got {bundle_indices=}."

                assert max(bundle_indices) < len(pg.bundle_specs), \
                    f"{bundle_indices=} out of range for PG with {len(pg.bundle_specs)} bundles"

                logger.info(
                    f"Found existing placement group {pg.bundle_specs=}. {bundle_indices=}"
                )

                # TODO: need to pin the TP group onto the same node for the RL framework integration case.

                return pg, bundle_indices
            else:
                logger.warning(
                    f"Ignoring TRTLLM_RAY_BUNDLE_INDICES={bundle_indices} because no global placement group is found."
                )

        if self.world_size % tp_size:
            raise ValueError("world_size must be a multiple of tp_size")

        head_tag = f"node:{self.master_address}"
        nodes = ray.nodes()
        gpus_per_node = int(nodes[0]["Resources"].get(
            "GPU", 0))  # assume symmetric across nodes

        bundle_cpu = bundle_gpu = min(tp_size, gpus_per_node)

        bundles, bundle_indices = [], []
        current = 0
        for rank in range(self.world_size):
            if current == 0:
                bundle = {"GPU": bundle_gpu, "CPU": bundle_cpu}
                if len(bundles) == 0:
                    bundle[head_tag] = 0.01  # to force placement on the head node
                bundles.append(bundle)

            bundle_indices.append(len(bundles) - 1)
            current = (current + 1) % bundle_gpu

        strategy = "PACK"
        logger.debug(
            f"[Strategy={strategy}] Bundles: {bundles} for tp_size: {tp_size} and world_size: {self.world_size}"
        )
        pg = placement_group(bundles, strategy=strategy)

        return pg, bundle_indices

    @property
    def enable_postprocess_parallel(self) -> bool:
        ret = super().enable_postprocess_parallel
        assert ret is False, "Postprocess parallel is not supported in RayExecutor"
        return ret