# TensorRT-LLM: tensorrt_llm/executor/executor.py
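"""Abstract generation executor interface for TensorRT-LLM.

Defines the GenerationExecutor base class, which accepts GenerationRequest
objects and returns GenerationResult futures, and the GenerationExecutor.create
factory, which selects a concrete backend (IPC proxy/worker, RPC proxy, or Ray
executor) based on the launch environment and the provided llm_args.
"""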

import atexit
import faulthandler
import multiprocessing
import platform
import signal
import traceback
from abc import ABC, abstractmethod
from collections.abc import Mapping
from pathlib import Path
from queue import Queue
from typing import (TYPE_CHECKING, AsyncIterable, Dict, Generator, List,
Optional, Union)
import numpy as np
import torch
from tensorrt_llm.inputs.multimodal import MultimodalParams
from tensorrt_llm.logger import logger, set_level
from .._utils import mpi_world_size
from ..bindings import executor as tllm
from ..builder import Engine
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.llm_args import BaseLlmArgs, TorchLlmArgs
from ..llmapi.llm_utils import KvCacheRetentionConfig
from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
need_spawn_mpi_workers)
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.utils import (AsyncQueue, enable_llm_debug,
enable_worker_single_process_for_tp1, logger_debug,
print_colored)
from ..sampling_params import (BatchedLogitsProcessor, LogprobParams,
SamplingParams)
from ..scheduling_params import SchedulingParams
from .ipc import FusedIpcQueue
from .postproc_worker import PostprocParams, PostprocWorkerConfig
from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
from .result import GenerationResult, IterationResult
from .utils import IntraProcessQueue, ProcessPoolExecutorSession, RequestError
if TYPE_CHECKING:
from .proxy import GenerationExecutorProxy
from .worker import GenerationExecutorWorker
__all__ = [
"GenerationExecutor",
"CppExecutorError",
]
if enable_llm_debug():
    # Mainly enables more detailed logging from the C++ runtime.
set_level("info")
async def empty_async_iterable() -> AsyncIterable:
if False: # ensures the function remains an async generator
yield
class CppExecutorError(RuntimeError):
def __init__(self, message: Optional[str] = None):
self.message = message
self.stack_trace = traceback.format_exc()
super().__init__(message)
def __str__(self):
return f"{self.message}\nStack trace:\n{self.stack_trace}"
class IterationResultQueue:
is_initialized: bool = False
# FusedIpcQueue or IntraProcessQueue is used to communicate results from workers to proxy
queue: Optional[Union[Queue, FusedIpcQueue, IntraProcessQueue]] = None
aqueue: Optional[AsyncQueue] = None
class GenerationExecutor(ABC):
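    """Abstract base class for executors that run generation requests.

    Concrete backends (GenerationExecutorProxy / GenerationExecutorWorker,
    GenerationExecutorRpcProxy, RayExecutor) are chosen by the create()
    factory. A minimal usage sketch, with illustrative arguments::

        executor = GenerationExecutor.create(engine, llm_args=llm_args)
        result = executor.generate(prompt_token_ids, sampling_params)
        executor.shutdown()
    """
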
def __init__(self,
num_postprocess_workers: int = 0,
postprocess_tokenizer_dir: Optional[str] = None,
is_llm_executor: Optional[bool] = None):
self.postproc_config = PostprocWorkerConfig(
num_postprocess_workers=num_postprocess_workers,
postprocess_tokenizer_dir=postprocess_tokenizer_dir)
self.kv_events_queues = IterationResultQueue()
self.stats_queues = IterationResultQueue()
atexit.register(self.shutdown)
        # Captures exceptions raised by background threads.
self._error_queue = Queue()
# A flag to avoid calling shutdown() recursively. This happens when the background threads raise errors.
self.doing_shutdown = False
self._last_client_id: int = 1
        # Whether this is the executor instance used by the LLM API.
self._is_llm_executor = is_llm_executor
self._iter_kv_events_result: IterationResult | None = None
self._iter_stats_result: IterationResult | None = None
@abstractmethod
def submit(self, request: GenerationRequest) -> GenerationResult:
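        """Submit a GenerationRequest and return a GenerationResult that
        tracks its outputs."""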
pass
@abstractmethod
def abort_request(self, request_id: int) -> None:
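        """Abort the in-flight request with the given request_id."""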
pass
def generate_async(
self,
prompt_token_ids: List[int],
sampling_params: SamplingParams,
query_token_ids: Optional[Union[torch.Tensor, np.ndarray, list]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
streaming: bool = False,
kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
disaggregated_params: Optional[DisaggregatedParams] = None,
trace_headers: Optional[Mapping[str, str]] = None,
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
cache_salt_id: Optional[int] = None,
arrival_time: Optional[float] = None,
) -> GenerationResult:
"""Generate output for the given prompt token ids in the asynchronous mode.
Asynchronous generation accepts single prompt only.
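
        Example:
            A minimal sketch; ``executor`` is an instance of a concrete
            subclass, and the token ids and sampling values are illustrative::

                sampling_params = SamplingParams(max_tokens=32)
                result = executor.generate_async([101, 2003, 2054], sampling_params)
                output = result.result()  # block until the request completes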
"""
assert isinstance(prompt_token_ids[0], int)
assert isinstance(sampling_params, SamplingParams)
self._maybe_initialize_iteration_results()
if postproc_params:
postproc_params.postproc_args.num_prompt_tokens = len(
prompt_token_ids)
request = GenerationRequest(
prompt_token_ids,
sampling_params=sampling_params,
postproc_params=postproc_params,
query_token_ids=query_token_ids,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
streaming=streaming,
kv_cache_retention_config=kv_cache_retention_config,
disaggregated_params=disaggregated_params,
trace_headers=trace_headers,
multimodal_params=multimodal_params,
scheduling_params=scheduling_params,
cache_salt_id=cache_salt_id,
arrival_time=arrival_time)
result = self.submit(request)
        # Release multimodal data promptly to free memory.
if hasattr(request, "multimodal_params"):
del request.multimodal_params
return result
def generate(
self,
prompt_token_ids: Union[List[int], List[List[int]]],
sampling_params: Union[SamplingParams, List[SamplingParams]],
query_token_ids: Optional[Union[torch.Tensor, np.ndarray, list]] = None,
lora_request: Optional[Union[LoRARequest, List[LoRARequest]]] = None,
prompt_adapter_request: Optional[Union[
PromptAdapterRequest, List[PromptAdapterRequest]]] = None,
disaggregated_params: Optional[DisaggregatedParams] = None,
) -> Union[GenerationResult, List[GenerationResult]]:
"""Generate output for the given prompt token ids in the synchronous mode.
Synchronous generation accepts either single prompt or batched prompts.
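
        Example:
            A minimal sketch; ``executor`` is an instance of a concrete
            subclass and the token ids are illustrative::

                sampling_params = SamplingParams(max_tokens=32)
                # Single prompt.
                result = executor.generate([101, 2003], sampling_params)
                # Batched prompts, one SamplingParams per prompt.
                results = executor.generate([[101, 2003], [102, 2054]],
                                            [sampling_params, sampling_params])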
"""
unbatched = isinstance(prompt_token_ids[0], int)
if unbatched:
prompt_token_ids = [prompt_token_ids]
if query_token_ids:
query_token_ids = [query_token_ids]
futures = []
for i, p in enumerate(prompt_token_ids):
if isinstance(sampling_params, list):
sp = sampling_params[i]
else:
sp = sampling_params
if isinstance(lora_request, list):
lora_req = lora_request[i]
else:
lora_req = lora_request
if isinstance(prompt_adapter_request, list):
pa_req = prompt_adapter_request[i]
else:
pa_req = prompt_adapter_request
future = self.generate_async(
p,
sampling_params=sp,
query_token_ids=query_token_ids,
lora_request=lora_req,
prompt_adapter_request=pa_req,
streaming=False,
disaggregated_params=disaggregated_params)
futures.append(future)
for future in futures:
future.result()
if unbatched:
futures = futures[0]
return futures
def _get_next_client_id(self):
        # Equivalent to (self._last_client_id + 1) % 2**64, i.e. uint64 wrap-around.
self._last_client_id = (self._last_client_id + 1) & ((1 << 64) - 1)
return self._last_client_id
def _get_logprob_params(
self, request: GenerationRequest) -> Optional[LogprobParams]:
"""Store logprobs-related fields from request for the later logprob calculation."""
logprob_params = None
if request.sampling_params.logprobs or request.sampling_params.prompt_logprobs:
logprob_params = LogprobParams(
logprobs=request.sampling_params.logprobs,
prompt_logprobs=request.sampling_params.prompt_logprobs,
                # Drop logits if the user didn't explicitly ask for them, or
                # when the postprocess-worker flow is used.
drop_context_logits=(
not request.sampling_params._need_return_context_logits)
or self.postproc_config.num_postprocess_workers > 0,
drop_generation_logits=(
not request.sampling_params._need_return_generation_logits)
or self.postproc_config.num_postprocess_workers > 0)
return logprob_params
def _maybe_initialize_iteration_results(self):
if self._is_llm_executor:
if self._iter_stats_result is None:
# singleton to store cpp runtime stats
self._iter_stats_result = IterationResult()
else:
# expect more engine stats whenever new prompts are submitted
self._iter_stats_result.mark_undone()
if self._iter_kv_events_result is None:
self._iter_kv_events_result = IterationResult()
else:
self._iter_kv_events_result.mark_undone()
def _handle_background_error(self, error: Optional[Exception | str] = None):
""" Process the errors from the threads or processes.
NOTE: This should be called in the main thread.
"""
if error is not None:
# For details please refer to the comment of `GenerationResult.error`
if isinstance(error, RequestError):
                # A per-request error; it can be caught and ignored.
if enable_llm_debug():
print_colored(f"Got per-request error: {repr(error)}\n",
"red")
elif isinstance(error, str):
                # A per-request error; it can be caught and ignored.
if enable_llm_debug():
print_colored(f"Got per-request error: {repr(error)}\n",
"red")
print_colored(str(traceback.extract_stack()) + "\n", "red")
error = RequestError(error)
else:
# Serious error from background thread or process
if not isinstance(error, BaseException):
error = RuntimeError(repr(error))
if enable_llm_debug():
print_colored(
f"Got background error: {repr(error)}, will shutdown the LLM instance\n",
"red")
self.shutdown()
raise error
        # Raise the first error in the queue. This method is called repeatedly,
        # so the user can choose to catch more than one error.
if not self._error_queue.empty():
e = self._error_queue.get()
self._error_queue.task_done()
self.shutdown()
            # The caller may catch and handle this exception.
raise e
def is_shutdown(self) -> bool:
return self.doing_shutdown
@abstractmethod
def shutdown(self):
pass
@property
def enable_postprocess_parallel(self) -> bool:
return self.postproc_config.enabled
def get_stats(self, timeout: float) -> List[dict]:
"""
Get iteration statistics from the runtime.
Args:
timeout (float): Max wait time in seconds when retrieving stats from queue.
Returns:
List[dict]: A list of runtime stats as dict.
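
        Example:
            A minimal sketch; stats become available only after prompts have
            been submitted through this executor::

                for stat in executor.get_stats(timeout=2.0):
                    print(stat)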
"""
if self._iter_stats_result is None:
print_colored(
"Iteration statistics are not available yet. To collect runtime statistics, please call get_stats() AFTER prompts have been submitted.\n",
"yellow")
return []
self._iter_stats_result.set_timeout(timeout)
return self._iter_stats_result.get_results()
def aget_stats(self, timeout: float) -> IterationResult:
"""
Get iteration statistics from the runtime.
Returns:
IterationResult: An async iterable object containing runtime stats.
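
        Example:
            A minimal sketch inside an async coroutine; prompts must already
            have been submitted through this executor::

                async for stat in executor.aget_stats(timeout=2.0):
                    print(stat)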
"""
if self._iter_stats_result is None:
print_colored(
"Iteration statistics are not available yet. To collect runtime statistics, please call get_stats_async() in async coroutine or the /metrics endpoint (if you're using trtllm-serve) AFTER prompts have been submitted.\n",
"yellow")
return empty_async_iterable()
self._iter_stats_result.set_timeout(timeout)
return self._iter_stats_result
def get_kv_events(self, timeout: float) -> List[dict]:
"""
Get iteration kv events from the runtime.
Args:
timeout (float): Max wait time in seconds when retrieving stats from queue.
Returns:
List[dict]: A list of runtime events as dict.
"""
assert self._iter_kv_events_result is not None, "KV Event IterationResult is not properly instantiated."
self._iter_kv_events_result.set_timeout(timeout)
return self._iter_kv_events_result.get_results()
    def aget_kv_events(self, timeout: Optional[float] = None) -> IterationResult:
"""
Get iteration kv events from the runtime.
Args:
timeout (float): Max wait time in seconds when retrieving stats from queue.
Returns:
IterationResult: An async iterable object containing runtime events.
"""
assert self._iter_kv_events_result is not None, "KV Event IterationResult is not properly instantiated."
self._iter_kv_events_result.set_timeout(timeout)
return self._iter_kv_events_result
@staticmethod
def _create_ray_executor(
worker_kwargs: Dict,
model_world_size: int,
postproc_worker_config: PostprocWorkerConfig,
is_llm_executor: bool,
tp_size: int,
):
logger.warning(f"Orchestrator is creating Ray executor")
from .ray_executor import RayExecutor
return RayExecutor(worker_kwargs,
model_world_size=model_world_size,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
tp_size=tp_size)
@staticmethod
def _create_rpc_executor(
worker_kwargs: Dict,
model_world_size: int,
mpi_session: Optional[MpiSession],
postproc_worker_config: PostprocWorkerConfig,
is_llm_executor: bool,
):
"""Create RPC-based executor (GenerationExecutorRpcProxy)."""
from .rpc_proxy import GenerationExecutorRpcProxy
logger.warning(f"Orchestrator is creating RPC executor")
return GenerationExecutorRpcProxy(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)
@staticmethod
def _create_ipc_executor(
worker_kwargs: Dict,
model_world_size: int,
mpi_session: Optional[MpiSession],
postproc_worker_config: PostprocWorkerConfig,
is_llm_executor: bool,
use_worker: bool = False,
):
"""Create IPC-based executor (GenerationExecutorProxy or GenerationExecutorWorker).
Args:
use_worker: If True, creates GenerationExecutorWorker (single process).
If False, creates GenerationExecutorProxy (multi-process with IPC).
"""
logger.warning(f"Orchestrator is creating IPC executor")
if use_worker:
from .worker import GenerationExecutorWorker
return GenerationExecutorWorker(**worker_kwargs,
is_llm_executor=is_llm_executor)
else:
from .proxy import GenerationExecutorProxy
return GenerationExecutorProxy(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)
@staticmethod
def create(
engine: Union[Path, Engine],
executor_config: Optional[tllm.ExecutorConfig] = None,
batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
model_world_size: int = 1,
world_size: int = 0,
mpi_session: Optional[MpiSession] = None,
reuse_mpi_comm: bool = False,
return_logits: bool = False,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[BaseLlmArgs] = None,
**args,
) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
if world_size == 0:
world_size = mpi_world_size()
if world_size > 1 and world_size < model_world_size:
raise RuntimeError(
"Cannot instantiate Generator for engine built "
f"for {model_world_size} ranks, while currently running "
f"on {world_size} ranks.")
postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
)
if postproc_worker_config.enabled:
logger_debug(
f"Using {postproc_worker_config.num_postprocess_workers} postprocess parallel processes.\n",
"green")
worker_kwargs = {
"engine": engine,
"executor_config": executor_config,
"batched_logits_processor": batched_logits_processor,
"hf_model_dir": hf_model_dir,
"tokenizer": tokenizer,
"llm_args": llm_args,
}
orchestrator_type = None if not isinstance(
llm_args, TorchLlmArgs) else llm_args.orchestrator_type
if orchestrator_type == "ray":
if llm_args and hasattr(llm_args, 'ray_worker_extension_cls'):
worker_kwargs[
"ray_worker_extension_cls"] = llm_args.ray_worker_extension_cls
return GenerationExecutor._create_ray_executor(
worker_kwargs,
model_world_size,
postproc_worker_config,
is_llm_executor=is_llm_executor,
tp_size=args.get("tp_size", 1))
elif orchestrator_type is not None and orchestrator_type != "rpc":
raise ValueError(
f"Unsupported orchestrator_type: {orchestrator_type}")
# The case where the Python main process is launched by mpirun
mpirun_launch = external_mpi_comm_available(model_world_size)
# The case where the Python main process utilizes mpi4py to spawn MPI workers
spawn_workers = need_spawn_mpi_workers(model_world_size)
orchestrator_is_rpc = llm_args and llm_args.orchestrator_type == "rpc"
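
        # Executor selection below, in order of precedence:
        #   1. Spawned MPI workers, or an external mpirun launch reusing its MPI
        #      comm: multi-process proxy (RPC proxy when orchestrator_type == "rpc").
        #   2. Returning logits, or the single-process TP1 override: in-process
        #      worker (RPC proxy when orchestrator_type == "rpc").
        #   3. Otherwise, on non-Windows platforms: multi-process proxy via mpi4py.
        #   4. Windows: multi-process proxy backed by a spawn-based process pool.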
if spawn_workers or (mpirun_launch and reuse_mpi_comm):
if reuse_mpi_comm:
assert mpi_session is not None, "reuse_mpi_comm requires an external MPI session"
if orchestrator_is_rpc:
return GenerationExecutor._create_rpc_executor(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)
return GenerationExecutor._create_ipc_executor(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
use_worker=False)
        # WAR: To keep logits gathering fast, use a single-process worker for TP1
        # and avoid the large IPC overhead.
        # WAR: Developers can also enable this manually; it makes TP1 debugging
        # easier. We will introduce a better solution in the future.
if return_logits or enable_worker_single_process_for_tp1():
logger.warning(
"Using single process worker for TP1, this may hurt streaming generation performance."
)
if orchestrator_is_rpc:
return GenerationExecutor._create_rpc_executor(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)
return GenerationExecutor._create_ipc_executor(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
use_worker=True)
        # For the single-GPU case:
        # Partition the workload across multiple processes for streaming performance.
        # Note that this requires users to protect their entrypoint with
        # `if __name__ == "__main__":`.
        if platform.system() != 'Windows':
if orchestrator_is_rpc:
return GenerationExecutor._create_rpc_executor(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)
return GenerationExecutor._create_ipc_executor(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=None, # use mpi4py
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
use_worker=False)
else:
ctx = multiprocessing.get_context("spawn")
            # ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot be used there.
mpi_session = ProcessPoolExecutorSession(n_workers=1,
mp_context=ctx)
# TODO: add rpc worker here
return GenerationExecutor._create_ipc_executor(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
use_worker=False)
def wait_first_completed(
self, futures: List[GenerationResult]
) -> Generator[GenerationResult, None, None]:
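        """Yield results as they finish.

        Already-finished futures are yielded first; the remaining ones are
        busy-polled against the executor's result tracking until they complete.
        """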
wait_set = set(futures)
        # Clear already-finished requests.
        for f in futures:
            if f._done:
                wait_set.remove(f)
                yield f

        # Busy-wait on the remaining active requests.
while len(wait_set) > 0:
fut = wait_set.pop()
if fut.request_id not in self._results:
yield fut
else:
wait_set.add(fut)
if enable_llm_debug():
print_colored("LLM debug mode enabled.\n", "yellow")
    # Dump the stacks of all live threads when the process is interrupted by SIGINT.
faulthandler.register(signal.SIGINT, all_threads=True)