import atexit import faulthandler import multiprocessing import platform import signal import traceback from abc import ABC, abstractmethod from collections.abc import Mapping from pathlib import Path from queue import Queue from typing import (TYPE_CHECKING, AsyncIterable, Dict, Generator, List, Optional, Union) import numpy as np import torch from tensorrt_llm.inputs.multimodal import MultimodalParams from tensorrt_llm.logger import logger, set_level from .._utils import mpi_world_size from ..bindings import executor as tllm from ..builder import Engine from ..disaggregated_params import DisaggregatedParams from ..llmapi.llm_args import BaseLlmArgs, TorchLlmArgs from ..llmapi.llm_utils import KvCacheRetentionConfig from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available, need_spawn_mpi_workers) from ..llmapi.tokenizer import TokenizerBase from ..llmapi.utils import (AsyncQueue, enable_llm_debug, enable_worker_single_process_for_tp1, logger_debug, print_colored) from ..sampling_params import (BatchedLogitsProcessor, LogprobParams, SamplingParams) from ..scheduling_params import SchedulingParams from .ipc import FusedIpcQueue from .postproc_worker import PostprocParams, PostprocWorkerConfig from .request import GenerationRequest, LoRARequest, PromptAdapterRequest from .result import GenerationResult, IterationResult from .utils import IntraProcessQueue, ProcessPoolExecutorSession, RequestError if TYPE_CHECKING: from .proxy import GenerationExecutorProxy from .worker import GenerationExecutorWorker __all__ = [ "GenerationExecutor", "CppExecutorError", ] if enable_llm_debug(): # Mainly enable more detailed logging from cpp runtime. set_level("info") async def empty_async_iterable() -> AsyncIterable: if False: # ensures the function remains an async generator yield class CppExecutorError(RuntimeError): def __init__(self, message: Optional[str] = None): self.message = message self.stack_trace = traceback.format_exc() super().__init__(message) def __str__(self): return f"{self.message}\nStack trace:\n{self.stack_trace}" class IterationResultQueue: is_initialized: bool = False # FusedIpcQueue or IntraProcessQueue is used to communicate results from workers to proxy queue: Optional[Union[Queue, FusedIpcQueue, IntraProcessQueue]] = None aqueue: Optional[AsyncQueue] = None class GenerationExecutor(ABC): def __init__(self, num_postprocess_workers: int = 0, postprocess_tokenizer_dir: Optional[str] = None, is_llm_executor: Optional[bool] = None): self.postproc_config = PostprocWorkerConfig( num_postprocess_workers=num_postprocess_workers, postprocess_tokenizer_dir=postprocess_tokenizer_dir) self.kv_events_queues = IterationResultQueue() self.stats_queues = IterationResultQueue() atexit.register(self.shutdown) # This is used to capture the exceptions from the threads. self._error_queue = Queue() # A flag to avoid calling shutdown() recursively. This happens when the background threads raise errors. self.doing_shutdown = False self._last_client_id: int = 1 # whether it's the executor instance of LLM API self._is_llm_executor = is_llm_executor self._iter_kv_events_result: IterationResult | None = None self._iter_stats_result: IterationResult | None = None @abstractmethod def submit(self, request: GenerationRequest) -> GenerationResult: pass @abstractmethod def abort_request(self, request_id: int) -> None: pass def generate_async( self, prompt_token_ids: List[int], sampling_params: SamplingParams, query_token_ids: Optional[Union[torch.Tensor, np.ndarray, list]] = None, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, streaming: bool = False, kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None, disaggregated_params: Optional[DisaggregatedParams] = None, trace_headers: Optional[Mapping[str, str]] = None, postproc_params: Optional[PostprocParams] = None, multimodal_params: Optional[MultimodalParams] = None, scheduling_params: Optional[SchedulingParams] = None, cache_salt_id: Optional[int] = None, arrival_time: Optional[float] = None, ) -> GenerationResult: """Generate output for the given prompt token ids in the asynchronous mode. Asynchronous generation accepts single prompt only. """ assert isinstance(prompt_token_ids[0], int) assert isinstance(sampling_params, SamplingParams) self._maybe_initialize_iteration_results() if postproc_params: postproc_params.postproc_args.num_prompt_tokens = len( prompt_token_ids) request = GenerationRequest( prompt_token_ids, sampling_params=sampling_params, postproc_params=postproc_params, query_token_ids=query_token_ids, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, streaming=streaming, kv_cache_retention_config=kv_cache_retention_config, disaggregated_params=disaggregated_params, trace_headers=trace_headers, multimodal_params=multimodal_params, scheduling_params=scheduling_params, cache_salt_id=cache_salt_id, arrival_time=arrival_time) result = self.submit(request) # release memory in time if hasattr(request, "multimodal_params"): del request.multimodal_params return result def generate( self, prompt_token_ids: Union[List[int], List[List[int]]], sampling_params: Union[SamplingParams, List[SamplingParams]], query_token_ids: Optional[Union[torch.Tensor, np.ndarray, list]] = None, lora_request: Optional[Union[LoRARequest, List[LoRARequest]]] = None, prompt_adapter_request: Optional[Union[ PromptAdapterRequest, List[PromptAdapterRequest]]] = None, disaggregated_params: Optional[DisaggregatedParams] = None, ) -> Union[GenerationResult, List[GenerationResult]]: """Generate output for the given prompt token ids in the synchronous mode. Synchronous generation accepts either single prompt or batched prompts. """ unbatched = isinstance(prompt_token_ids[0], int) if unbatched: prompt_token_ids = [prompt_token_ids] if query_token_ids: query_token_ids = [query_token_ids] futures = [] for i, p in enumerate(prompt_token_ids): if isinstance(sampling_params, list): sp = sampling_params[i] else: sp = sampling_params if isinstance(lora_request, list): lora_req = lora_request[i] else: lora_req = lora_request if isinstance(prompt_adapter_request, list): pa_req = prompt_adapter_request[i] else: pa_req = prompt_adapter_request future = self.generate_async( p, sampling_params=sp, query_token_ids=query_token_ids, lora_request=lora_req, prompt_adapter_request=pa_req, streaming=False, disaggregated_params=disaggregated_params) futures.append(future) for future in futures: future.result() if unbatched: futures = futures[0] return futures def _get_next_client_id(self): # (self._last_client_id + 1) % UINT64_MAX self._last_client_id = (self._last_client_id + 1) & ((1 << 64) - 1) return self._last_client_id def _get_logprob_params( self, request: GenerationRequest) -> Optional[LogprobParams]: """Store logprobs-related fields from request for the later logprob calculation.""" logprob_params = None if request.sampling_params.logprobs or request.sampling_params.prompt_logprobs: logprob_params = LogprobParams( logprobs=request.sampling_params.logprobs, prompt_logprobs=request.sampling_params.prompt_logprobs, # drop logits if users didn't explicitly ask for it, or if it's using PostProcess flow drop_context_logits=( not request.sampling_params._need_return_context_logits) or self.postproc_config.num_postprocess_workers > 0, drop_generation_logits=( not request.sampling_params._need_return_generation_logits) or self.postproc_config.num_postprocess_workers > 0) return logprob_params def _maybe_initialize_iteration_results(self): if self._is_llm_executor: if self._iter_stats_result is None: # singleton to store cpp runtime stats self._iter_stats_result = IterationResult() else: # expect more engine stats whenever new prompts are submitted self._iter_stats_result.mark_undone() if self._iter_kv_events_result is None: self._iter_kv_events_result = IterationResult() else: self._iter_kv_events_result.mark_undone() def _handle_background_error(self, error: Optional[Exception | str] = None): """ Process the errors from the threads or processes. NOTE: This should be called in the main thread. """ if error is not None: # For details please refer to the comment of `GenerationResult.error` if isinstance(error, RequestError): # A per-request error, can be captured and ignored if enable_llm_debug(): print_colored(f"Got per-request error: {repr(error)}\n", "red") elif isinstance(error, str): # A per-request error, can be captured and ignored if enable_llm_debug(): print_colored(f"Got per-request error: {repr(error)}\n", "red") print_colored(str(traceback.extract_stack()) + "\n", "red") error = RequestError(error) else: # Serious error from background thread or process if not isinstance(error, BaseException): error = RuntimeError(repr(error)) if enable_llm_debug(): print_colored( f"Got background error: {repr(error)}, will shutdown the LLM instance\n", "red") self.shutdown() raise error # Here we raise the first error in the queue. This method will be called repeatedly and user can choose to catch # more than one error. if not self._error_queue.empty(): e = self._error_queue.get() self._error_queue.task_done() self.shutdown() # We can catch some exceptions here. raise e def is_shutdown(self) -> bool: return self.doing_shutdown @abstractmethod def shutdown(self): pass @property def enable_postprocess_parallel(self) -> bool: return self.postproc_config.enabled def get_stats(self, timeout: float) -> List[dict]: """ Get iteration statistics from the runtime. Args: timeout (float): Max wait time in seconds when retrieving stats from queue. Returns: List[dict]: A list of runtime stats as dict. """ if self._iter_stats_result is None: print_colored( "Iteration statistics are not available yet. To collect runtime statistics, please call get_stats() AFTER prompts have been submitted.\n", "yellow") return [] self._iter_stats_result.set_timeout(timeout) return self._iter_stats_result.get_results() def aget_stats(self, timeout: float) -> IterationResult: """ Get iteration statistics from the runtime. Returns: IterationResult: An async iterable object containing runtime stats. """ if self._iter_stats_result is None: print_colored( "Iteration statistics are not available yet. To collect runtime statistics, please call get_stats_async() in async coroutine or the /metrics endpoint (if you're using trtllm-serve) AFTER prompts have been submitted.\n", "yellow") return empty_async_iterable() self._iter_stats_result.set_timeout(timeout) return self._iter_stats_result def get_kv_events(self, timeout: float) -> List[dict]: """ Get iteration kv events from the runtime. Args: timeout (float): Max wait time in seconds when retrieving stats from queue. Returns: List[dict]: A list of runtime events as dict. """ assert self._iter_kv_events_result is not None, "KV Event IterationResult is not properly instantiated." self._iter_kv_events_result.set_timeout(timeout) return self._iter_kv_events_result.get_results() def aget_kv_events(self, timeout=None) -> IterationResult: """ Get iteration kv events from the runtime. Args: timeout (float): Max wait time in seconds when retrieving stats from queue. Returns: IterationResult: An async iterable object containing runtime events. """ assert self._iter_kv_events_result is not None, "KV Event IterationResult is not properly instantiated." self._iter_kv_events_result.set_timeout(timeout) return self._iter_kv_events_result @staticmethod def _create_ray_executor( worker_kwargs: Dict, model_world_size: int, postproc_worker_config: PostprocWorkerConfig, is_llm_executor: bool, tp_size: int, ): logger.warning(f"Orchestrator is creating Ray executor") from .ray_executor import RayExecutor return RayExecutor(worker_kwargs, model_world_size=model_world_size, postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor, tp_size=tp_size) @staticmethod def _create_rpc_executor( worker_kwargs: Dict, model_world_size: int, mpi_session: Optional[MpiSession], postproc_worker_config: PostprocWorkerConfig, is_llm_executor: bool, ): """Create RPC-based executor (GenerationExecutorRpcProxy).""" from .rpc_proxy import GenerationExecutorRpcProxy logger.warning(f"Orchestrator is creating RPC executor") return GenerationExecutorRpcProxy( worker_kwargs, model_world_size=model_world_size, mpi_session=mpi_session, postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor) @staticmethod def _create_ipc_executor( worker_kwargs: Dict, model_world_size: int, mpi_session: Optional[MpiSession], postproc_worker_config: PostprocWorkerConfig, is_llm_executor: bool, use_worker: bool = False, ): """Create IPC-based executor (GenerationExecutorProxy or GenerationExecutorWorker). Args: use_worker: If True, creates GenerationExecutorWorker (single process). If False, creates GenerationExecutorProxy (multi-process with IPC). """ logger.warning(f"Orchestrator is creating IPC executor") if use_worker: from .worker import GenerationExecutorWorker return GenerationExecutorWorker(**worker_kwargs, is_llm_executor=is_llm_executor) else: from .proxy import GenerationExecutorProxy return GenerationExecutorProxy( worker_kwargs, model_world_size=model_world_size, mpi_session=mpi_session, postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor) @staticmethod def create( engine: Union[Path, Engine], executor_config: Optional[tllm.ExecutorConfig] = None, batched_logits_processor: Optional[BatchedLogitsProcessor] = None, model_world_size: int = 1, world_size: int = 0, mpi_session: Optional[MpiSession] = None, reuse_mpi_comm: bool = False, return_logits: bool = False, postproc_worker_config: Optional[PostprocWorkerConfig] = None, is_llm_executor: Optional[bool] = None, hf_model_dir: Optional[Path] = None, tokenizer: Optional[TokenizerBase] = None, llm_args: Optional[BaseLlmArgs] = None, **args, ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]: if world_size == 0: world_size = mpi_world_size() if world_size > 1 and world_size < model_world_size: raise RuntimeError( "Cannot instantiate Generator for engine built " f"for {model_world_size} ranks, while currently running " f"on {world_size} ranks.") postproc_worker_config = postproc_worker_config or PostprocWorkerConfig( ) if postproc_worker_config.enabled: logger_debug( f"Using {postproc_worker_config.num_postprocess_workers} postprocess parallel processes.\n", "green") worker_kwargs = { "engine": engine, "executor_config": executor_config, "batched_logits_processor": batched_logits_processor, "hf_model_dir": hf_model_dir, "tokenizer": tokenizer, "llm_args": llm_args, } orchestrator_type = None if not isinstance( llm_args, TorchLlmArgs) else llm_args.orchestrator_type if orchestrator_type == "ray": if llm_args and hasattr(llm_args, 'ray_worker_extension_cls'): worker_kwargs[ "ray_worker_extension_cls"] = llm_args.ray_worker_extension_cls return GenerationExecutor._create_ray_executor( worker_kwargs, model_world_size, postproc_worker_config, is_llm_executor=is_llm_executor, tp_size=args.get("tp_size", 1)) elif orchestrator_type is not None and orchestrator_type != "rpc": raise ValueError( f"Unsupported orchestrator_type: {orchestrator_type}") # The case where the Python main process is launched by mpirun mpirun_launch = external_mpi_comm_available(model_world_size) # The case where the Python main process utilizes mpi4py to spawn MPI workers spawn_workers = need_spawn_mpi_workers(model_world_size) orchestrator_is_rpc = llm_args and llm_args.orchestrator_type == "rpc" if spawn_workers or (mpirun_launch and reuse_mpi_comm): if reuse_mpi_comm: assert mpi_session is not None, "reuse_mpi_comm requires an external MPI session" if orchestrator_is_rpc: return GenerationExecutor._create_rpc_executor( worker_kwargs, model_world_size=model_world_size, mpi_session=mpi_session, postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor) return GenerationExecutor._create_ipc_executor( worker_kwargs, model_world_size=model_world_size, mpi_session=mpi_session, postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor, use_worker=False) # WAR: For the performance of gathering logits, we use single process worker # for TP1 to avoid the large overhead of IPC. # WAR: Developers can enable this manually, this will be easier for TP1 # debugging. We will introduce a better solution in the future. if return_logits or enable_worker_single_process_for_tp1(): logger.warning( "Using single process worker for TP1, this may hurt streaming generation performance." ) if orchestrator_is_rpc: return GenerationExecutor._create_rpc_executor( worker_kwargs, model_world_size=model_world_size, mpi_session=mpi_session, postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor) return GenerationExecutor._create_ipc_executor( worker_kwargs, model_world_size=model_world_size, mpi_session=mpi_session, postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor, use_worker=True) # For single-gpu case: # Partition the workload to multiple process for streaming performance. # While this requires uses to protect their entrypoint to # `if __name__ == "__main__":`. if not platform.system() == 'Windows': if orchestrator_is_rpc: return GenerationExecutor._create_rpc_executor( worker_kwargs, model_world_size=model_world_size, mpi_session=mpi_session, postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor) return GenerationExecutor._create_ipc_executor( worker_kwargs, model_world_size=model_world_size, mpi_session=None, # use mpi4py postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor, use_worker=False) else: ctx = multiprocessing.get_context("spawn") # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot. mpi_session = ProcessPoolExecutorSession(n_workers=1, mp_context=ctx) # TODO: add rpc worker here return GenerationExecutor._create_ipc_executor( worker_kwargs, model_world_size=model_world_size, mpi_session=mpi_session, postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor, use_worker=False) def wait_first_completed( self, futures: List[GenerationResult] ) -> Generator[GenerationResult, None, None]: wait_set = set(futures) # clear already-finished requests for f in futures: if f._done: wait_set.pop(f) yield f # wait remaining active requests while len(wait_set) > 0: fut = wait_set.pop() if fut.request_id not in self._results: yield fut else: wait_set.add(fut) if enable_llm_debug(): print_colored("LLM debug mode enabled.\n", "yellow") # This will dump all the alive threads when the process is interrupted by SIGINT. faulthandler.register(signal.SIGINT, all_threads=True)