import copy
import datetime
import enum
import json
import os
import weakref
from pathlib import Path
from queue import Queue
from typing import Dict, List, Optional, Tuple, Union

import psutil
import torch

from tensorrt_llm.logger import logger

from .._torch.pyexecutor.llm_request import LlmResponse
from .._utils import (global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank,
                      nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import BaseLlmArgs, PybindMirror
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import (_SyncQueue, get_numa_aware_cpu_affinity,
                            logger_debug)
from ..lora_manager import LoraManager
from ..metrics import RequestEventTiming
from ..prompt_adapter_manager import PromptAdapterManager
from ..runtime import ModelConfig
from ..runtime.model_runner import _engine_config_to_model_config
from ..sampling_params import BatchedLogitsProcessor, SamplingParams
from .executor import GenerationExecutor, IterationResultQueue
from .ipc import FusedIpcQueue, IpcQueue
from .postproc_worker import (PostprocParams, PostprocWorker,
                              PostprocWorkerConfig)
from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
from .result import (GenerationResult, LogProbsResult, ResponseWrapper,
                     compute_logprobs)
from .utils import (ErrorResponse, IntraProcessQueue, RequestError,
                    is_llm_response)

__all__ = [
    "BaseWorker",
    "_init_hf_modules",
]


def _init_hf_modules():
    """Initialize cached HuggingFace modules for models with trust_remote_code=True.

    This is safe to call multiple times (idempotent) and should be called:
    1. At module import time (for main process and spawned subprocesses)
    2. At worker_main entry (for forked processes or external MPI ranks)

    References:
        https://github.com/vllm-project/vllm/pull/871
    """
    try:
        from transformers.dynamic_module_utils import init_hf_modules
        init_hf_modules()
        logger.debug("HF modules initialized")
    except ImportError as e:
        logger.warning(f"ImportError initializing HF modules: {e}")
    except Exception as e:
        logger.error(f"Exception initializing HF modules: {e}")


_init_hf_modules()


class BaseWorker(GenerationExecutor):

    class WorkerExit(GeneratorExit):
        pass

    def __init__(
        self,
        engine: Union[Path, Engine],
        executor_config: Optional[tllm.ExecutorConfig] = None,
        batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
        postproc_worker_config: Optional[PostprocWorkerConfig] = None,
        is_llm_executor: Optional[bool] = None,
        hf_model_dir: Optional[Path] = None,
        tokenizer: Optional[TokenizerBase] = None,
        llm_args: Optional[BaseLlmArgs] = None,
    ) -> None:
        postproc_config = postproc_worker_config or PostprocWorkerConfig()
        super().__init__(
            num_postprocess_workers=postproc_config.num_postprocess_workers,
            postprocess_tokenizer_dir=postproc_config.postprocess_tokenizer_dir,
            is_llm_executor=is_llm_executor,
        )

        # inputs
        self._engine = engine
        self._executor_config = executor_config
        self._batched_logits_processor = batched_logits_processor
        self._postproc_worker_config = postproc_worker_config
        self._is_llm_executor = is_llm_executor
        self._hf_model_dir = hf_model_dir
        self._tokenizer = tokenizer
        self.llm_args = llm_args

        self.engine = None
        self.result_queue: Optional[IpcQueue] = None
        self.postproc_queues: Optional[List[IpcQueue]] = None
        self.rank = mpi_rank()
        self.global_rank = global_mpi_rank()
        # mapping: client_id -> GenerationResult
        self._results: Dict[int, GenerationResult] = {}
        # mapping: client_id from Proxy -> request_id
        # returned from runtime backend
        self._client_id_to_request_id: Dict[int, int] = {}
        self._await_response_helper = AwaitResponseHelper(weakref.proxy(self))

        self._backend = None if llm_args is None else llm_args.backend
        self._is_pytorch_backend = self._backend in ["pytorch", "_autodeploy"]
        self._lora_config = llm_args.lora_config if self._is_pytorch_backend else None

        if global_mpi_size() > 1:
            logger.set_rank(self.global_rank)

    def _configure_affinity(self, device_id):
        '''Probe and configure the CPU affinity of the worker based on NUMA topology.

        Args:
            device_id: The CUDA device ID used to determine the optimal CPU affinity.

        Note:
            If the process already has constrained affinity, a warning is logged.
            Configuration is handled as follows:

            TLLM_NUMA_AWARE_WORKER_AFFINITY unset -> Affinity is configured
                automatically if it is unconstrained, and removed if it was
                constrained externally by the user.
            TLLM_NUMA_AWARE_WORKER_AFFINITY = 1 -> Affinity is unconditionally
                auto-configured.
            TLLM_NUMA_AWARE_WORKER_AFFINITY = 0 or any other value -> Affinity
                is unconditionally _not_ auto-configured.
        '''
        # Get the current affinity setting
        pid = os.getpid()
        process = psutil.Process(pid)
        cpu_affinity = process.cpu_affinity()
        all_cpus = list(range(psutil.cpu_count()))
        constrained_affinity = (cpu_affinity != all_cpus)
        numa_aware_affinity = os.environ.get("TLLM_NUMA_AWARE_WORKER_AFFINITY")

        # If affinity is constrained but the user hasn't explicitly
        # requested NUMA-aware affinity, remove the constraints.
        if constrained_affinity:
            logger.warning(
                f"Worker process {pid} is affined to run on the following CPUs: "
                f"{cpu_affinity} (subset of all logical CPUs). This may harm "
                f"performance if set incorrectly.")
            if numa_aware_affinity is None:
                logger.warning(
                    f"Worker process {pid} has constrained CPU affinity "
                    f"but `TLLM_NUMA_AWARE_WORKER_AFFINITY` is not set. "
                    f"Removing CPU affinity constraints.")
                process.cpu_affinity(all_cpus)

        # If affinity is unconstrained and the user hasn't explicitly
        # prohibited it, or the user has explicitly requested it, choose the
        # optimal affinity based upon the NUMA topology.
        if ((numa_aware_affinity is None and not constrained_affinity)
                or (numa_aware_affinity == "1")):
            process.cpu_affinity(get_numa_aware_cpu_affinity(device_id))
            logger.info(
                f"Worker process {pid} CPU affinity set to "
                f"{process.cpu_affinity()} for optimal NUMA-aware scheduling.")

    def _get_comm_ranks_device_id(self):
        device_id = self.global_rank % torch.cuda.device_count()
        torch.cuda.set_device(device_id)
        # Make sure the C++ executor uses the same devices/ranks as py_executor
        global_rank = global_mpi_rank()
        comm_ranks = mpi_comm().allgather(global_rank)
        device_ids = mpi_comm().allgather(device_id)
        self._configure_affinity(device_id)
        return comm_ranks, device_ids

    def setup_engine(self):
        """
        Setup the engine for the worker.
""" if isinstance(self._engine, list): self._engine = self._engine[self.rank] def _create_py_executor(): args = {} assert hasattr( self.llm_args, "backend" ), "llm_args should be with backend in _create_py_executor" _ = self._get_comm_ranks_device_id() if self._backend == "pytorch": from tensorrt_llm._torch.pyexecutor.py_executor_creator import \ create_py_executor create_executor = create_py_executor args["llm_args"] = self.llm_args args["checkpoint_dir"] = self._hf_model_dir args["tokenizer"] = self._tokenizer elif self._backend == "_autodeploy": from tensorrt_llm._torch.auto_deploy.llm_args import \ LlmArgs as ADLlmArgs from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \ create_autodeploy_executor create_executor = create_autodeploy_executor assert isinstance(self.llm_args, ADLlmArgs) args["ad_config"] = self.llm_args args["tokenizer"] = self._tokenizer else: raise ValueError(f"Unsupported backend config: {self._backend}") # Define additional attributes that can be used later, such as in _deduce_max_tokens self.mapping = self.llm_args.parallel_config.to_mapping() self.checkpoint_loader = None if self._backend == "pytorch": from tensorrt_llm._torch.pyexecutor.model_loader import \ _construct_checkpoint_loader self.checkpoint_loader = _construct_checkpoint_loader( self.llm_args.backend, self.llm_args.checkpoint_loader, self.llm_args.checkpoint_format) self.max_seq_len = self.llm_args.max_seq_len # creare_py_executor may change some fields of llm_args _executor = create_executor(**args) if _executor.max_seq_len is not None: # max_seq_len might be updated by model engine as in create_py_executor self.max_seq_len = _executor.max_seq_len return _executor def _create_engine(executor_config): engine = self._engine if executor_config is None: executor_config = tllm.ExecutorConfig(1) executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig( processor_batched=self._batched_logits_processor, replicate=False) comm_ranks, device_ids = self._get_comm_ranks_device_id() executor_config.parallel_config = tllm.ParallelConfig( participant_ids=comm_ranks, device_ids=device_ids) if isinstance(engine, Engine): return tllm.Executor(engine.engine, json.dumps(engine.config.to_dict(), cls=ConfigEncoder), tllm.ModelType.DECODER_ONLY, executor_config=executor_config, managed_weights=engine.managed_weights) assert not hasattr(executor_config, "backend") return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY, executor_config) self.engine = _create_py_executor( ) if self.llm_args is not None else _create_engine( self._executor_config) self._lora_manager: Optional[LoraManager] = None self._prompt_adapter_manager: Optional[PromptAdapterManager] = None self._runtime_model_config: Optional[ModelConfig] = None if self.rank == 0 and isinstance(self.engine, tllm.Executor): if isinstance(self.engine, Engine): engine_config = self.engine.config else: engine_config = EngineConfig.from_json_file( f"{self._engine}/config.json") self._runtime_model_config = _engine_config_to_model_config( engine_config) if engine_config.build_config.plugin_config.lora_plugin: # TODO(azuker): Passing peft cache manager to LoraManager is used for LoRA optimization # (see LoraManager constructor docstring). Getting the peft cache manager from this # point in the TRT flow is currently not supported (it's at the CPP # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA # optimization is not available in TRT-python flow. 
                self._lora_manager = LoraManager(
                    mapping=engine_config.pretrained_config.mapping,
                    model_config=self._runtime_model_config,
                    cpp_peft_cache_manager=None)
            if engine_config.build_config.max_prompt_embedding_table_size > 0:
                self._prompt_adapter_manager = PromptAdapterManager()

        if self._backend == "pytorch" and self._lora_config is not None:
            from tensorrt_llm._torch.pyexecutor.resource_manager import \
                ResourceManagerType
            peft_cache_manager = self.engine.resource_manager.resource_managers.get(
                ResourceManagerType.PEFT_CACHE_MANAGER)
            self._lora_manager = peft_cache_manager.get_lora_manager()
            lora_model_config = self.engine.model_engine.lora_model_config
            assert lora_model_config is not None
            self._lora_model_config = lora_model_config

    def await_responses(self, timeout: Optional[float] = None) -> list:
        return self.engine.await_responses(timeout=datetime.timedelta(
            seconds=timeout) if timeout is not None else None)

    def fetch_stats(self) -> list:
        if isinstance(self.engine, tllm.Executor):
            iter_stats = self.engine.get_latest_iteration_stats()
            # TODO: Support request stats with the TRT engine.
            # This would require ensuring iter and req stats have the same size.
            return [(iter_stat, None) for iter_stat in iter_stats]
        else:
            return self.engine.get_latest_iteration_stats()

    def fetch_kv_cache_events(self) -> list:
        if isinstance(self.engine, tllm.Executor):
            return self.engine.get_latest_kv_cache_events()
        else:
            return self.engine.get_latest_kv_cache_events()

    def set_result_queue(self, queue):
        """In multi-gpu mode, result_queue will be set here to communicate
        between the proxy and the worker 0 process."""
        assert self.postproc_queues is None
        self.result_queue = queue

    def set_postproc_queues(self, queues: List["IpcQueue"]):
        """ Set the IPC queues for feeding post-processing processes. """
        assert self.result_queue is None
        self.postproc_queues = queues

    def _set_iteration_result_queue(self, it_result_queue: IterationResultQueue,
                                    queue: Union[Queue, FusedIpcQueue,
                                                 IntraProcessQueue]):
        assert not it_result_queue.is_initialized, "Iteration result queue should not already be initialized."
        it_result_queue.is_initialized = True
        it_result_queue.queue = queue
        it_result_queue.aqueue = None

    def return_queue(self, client_id: int):
        """ If a centralized result queue is registered (used for communication
        with the proxy), send the message there.
        Otherwise, push the result directly into the GenerationResult queue.
        """
        if self.result_queue is not None:
            return self.result_queue
        return self._results[client_id].queue

    def abort_request(self, client_id: int) -> None:
        # NOTE: the request_id here is generated by the cpp runtime, not the client_id
        if self.engine.can_enqueue_requests():
            request_id = self._client_id_to_request_id.get(client_id, None)
            if request_id is None:
                logger.warning(
                    f"Request of client_id {client_id} is finished, cannot abort it."
                )
                return
            self.engine.cancel_request(request_id)

    def _engine_response_callback(self, response: tllm.Response):
        return response

    def _has_background_error(self) -> bool:
        return not self._error_queue.empty()

    def _create_error_response(self, response: tllm.Response) -> ErrorResponse:
        bck_error = self._error_queue.get_nowait()
        assert isinstance(bck_error, Exception)
        return ErrorResponse(response.client_id, str(bck_error),
                             response.request_id)

    def start(self):
        raise NotImplementedError(
            "start method is not implemented in BaseWorker")

    def _load_lora_adapter(self, lora_request: LoRARequest) -> bool:
        """Returns True if the adapter was loaded by this call,
        False if it was already loaded."""
        adapter_id = str(lora_request.adapter_id)
        newly_loaded_uids = self._lora_manager.load_from_ckpt(
            [lora_request.path],
            model_config=self._runtime_model_config
            if self._runtime_model_config is not None else
            self._lora_model_config,
            uids=[adapter_id],
            ckpt_source=lora_request.ckpt_source)
        return adapter_id in newly_loaded_uids

    def _load_prompt_adapter(self,
                             prompt_adapter_request: PromptAdapterRequest):
        self._prompt_adapter_manager.load_from_ckpt(
            [prompt_adapter_request.local_path],
            model_config=self._runtime_model_config,
            uids=[str(prompt_adapter_request.adapter_id)])

    def _enqueue_request(self,
                         request: GenerationRequest,
                         result_wait_queue=None) -> int:
        assert request.id is not None
        py_lora_path = None
        if self._lora_manager is not None and request.lora_request is not None:
            adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache(
                request.lora_request.adapter_id)
            self._load_lora_adapter(request.lora_request)
            uid = str(request.lora_request.adapter_id)
            lora_config = tllm.LoraConfig(
                task_id=request.lora_request.adapter_id,
                weights=self._lora_manager.cpp_lora_weights[uid]
                if not adapter_in_cache else None,
                config=self._lora_manager.cpp_lora_config[uid])
            py_lora_path = request.lora_request.lora_path
        else:
            lora_config = None

        prompt_token_ids = copy.deepcopy(request.prompt_token_ids)
        prompt_tuning_config = None
        if request.prompt_adapter_request is not None:
            self._load_prompt_adapter(request.prompt_adapter_request)
            uid = str(request.prompt_adapter_request.adapter_id)
            prompt_tuning_config = tllm.PromptTuningConfig(
                self._prompt_adapter_manager.uid_to_weights[uid])
            vocab_size = self._runtime_model_config.vocab_size
            pa_length = prompt_tuning_config.embedding_table.size(0)
            prompt_token_ids = list(range(
                vocab_size, vocab_size + pa_length)) + prompt_token_ids

        # MULTIMODAL
        # NOTE: Since we only support the PyTorch backend for multimodal, we send
        # multimodal_data through the 'py_multimodal_data' field, except for
        # `multimodal_input`, which needs to go through the C++ runtime.
        multimodal_input = None
        if request.multimodal_params is not None and request.multimodal_params.has_content(
        ):
            if request.multimodal_params.multimodal_input is not None:
                multimodal_input = tllm.MultimodalInput(
                    multimodal_hashes=request.multimodal_params.
                    multimodal_input.multimodal_hashes,
                    multimodal_positions=request.multimodal_params.
                    multimodal_input.multimodal_positions,
                    multimodal_lengths=request.multimodal_params.
                    multimodal_input.multimodal_lengths)
                # NOTE: Setting to None here to avoid sending multimodal_input
                # again through the 'py_multimodal_data' field
                request.multimodal_params.multimodal_input = None

        context_phase_params = None
        request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
        if request.disaggregated_params is not None:
            assert (
                not self._is_pytorch_backend
                or self.engine.kv_cache_transceiver is not None
                or request.disaggregated_params.request_type
                == "context_and_generation"
            ), "kv_cache_transceiver is disabled, please set `cache_transceiver_config: backend:` in the config file for disaggregated serving"
            request_type = request.disaggregated_params.get_request_type()
            if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY:
                context_phase_params = request.disaggregated_params.get_context_phase_params(
                )

        if self._is_pytorch_backend:
            if not self.llm_args.disable_overlap_scheduler:
                is_disaggregated = self.engine.kv_cache_transceiver is not None
                if is_disaggregated and (
                        request_type
                        == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
                    raise ValueError(
                        "Context-only requests are not supported in the pytorch backend when the overlap scheduler is enabled."
                    )

        assert request.id is not None

        def _deduce_max_tokens(request: GenerationRequest,
                               executor_config: tllm.ExecutorConfig,
                               llm_args: Optional[BaseLlmArgs] = None) -> int:
            # deduce max_tokens when it's not set by the user
            max_tokens = request.sampling_params.max_tokens
            query_token_len = len(
                request.query_token_ids) if request.query_token_ids else 0
            cp_size = 1
            max_seq_len = None
            if llm_args is not None:
                # deduce max_tokens from the LLM args
                assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
                if hasattr(self,
                           "mapping") and self.mapping.cp_size is not None:
                    cp_size = self.mapping.cp_size
                max_seq_len = getattr(self, "max_seq_len", None)
            else:
                # deduce max_tokens from the executor config
                if hasattr(executor_config, "mapping"
                           ) and executor_config.mapping.cp_size is not None:
                    cp_size = executor_config.mapping.cp_size
                max_seq_len = getattr(executor_config, "max_seq_len", None)
            if max_seq_len is None:
                logger.warning("`default_max_tokens` cannot be deduced")
                if max_tokens is None:
                    raise ValueError(
                        "`max_tokens` must be set when `default_max_tokens` cannot be deduced"
                    )
                else:
                    # use max_tokens when default_max_tokens can't be deduced
                    return max_tokens
            if executor_config is not None:
                assert (
                    len(prompt_token_ids) <= executor_config.max_seq_len
                ), f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than `max_seq_len` ({executor_config.max_seq_len})"
            splited_prompt_len = int(len(prompt_token_ids) / cp_size)
            default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
            if default_max_tokens <= 0:
                # Raise an error when `default_max_tokens` is not positive,
                # since max_tokens must not exceed `default_max_tokens`.
                raise ValueError(
                    f"`default_max_tokens` ({default_max_tokens}) must be greater than 0, "
                    f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})"
                    f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
                )
            # default_max_tokens is the largest available value
            if max_tokens is None:
                return default_max_tokens
            elif max_tokens > default_max_tokens and default_max_tokens > 0:
                logger.warning(
                    f"User-specified `max_tokens` ({max_tokens}) is greater than deduced "
                    f"`default_max_tokens` ({default_max_tokens}), using default_max_tokens instead."
                )
                return default_max_tokens
            elif max_tokens <= 0:
                raise ValueError(
                    f"`max_tokens` ({max_tokens}) must be greater than 0")
            else:
                return max_tokens

        try:
            executor_request = tllm.Request(
                client_id=request.id,
                input_token_ids=prompt_token_ids,
                max_tokens=_deduce_max_tokens(
                    request,
                    self._executor_config if not self.llm_args else None,
                    self.llm_args),
                streaming=request.streaming,
                sampling_config=request.sampling_params._get_sampling_config(),
                end_id=-1 if request.sampling_params.ignore_eos else
                request.sampling_params.end_id,
                pad_id=request.sampling_params.pad_id,
                output_config=request.sampling_params._get_output_config(
                    is_pytorch_backend=self._is_pytorch_backend),
                # Beam search enforces return_all_generated_tokens=True regardless of the passed value
                return_all_generated_tokens=False,
                # convert python config into pybind config
                lookahead_config=PybindMirror.maybe_to_pybind(
                    request.sampling_params.lookahead_config),
                guided_decoding_params=request.sampling_params.
                _get_guided_decoding_params(),
                bad_words=request.sampling_params._get_bad_words(),
                stop_words=[] if request.sampling_params.ignore_eos else
                request.sampling_params._get_stop_words(),
                embedding_bias=request.sampling_params.embedding_bias,
                lora_config=lora_config,
                prompt_tuning_config=prompt_tuning_config,
                multimodal_input=multimodal_input,
                # NOTE: `multimodal_embedding` and `mrope_config` will be in
                # MultimodalParams.multimodal_data. And this will be handled
                # below by `py_multimodal_data`.
                multimodal_embedding=None,
                mrope_config=None,
                logits_post_processor_name=(
                    tllm.Request.BATCHED_POST_PROCESSOR_NAME
                    if request.sampling_params.apply_batched_logits_processor
                    else None),
                logits_post_processor=None if self._is_pytorch_backend else
                request.sampling_params.logits_processor,
                kv_cache_retention_config=request.kv_cache_retention_config,
                context_phase_params=context_phase_params,
                type=request_type,
                cache_salt_id=request.cache_salt_id)
            executor_request.py_num_logprobs = request.sampling_params.logprobs
            executor_request.py_lora_path = py_lora_path

            if self._is_pytorch_backend and request.multimodal_params is not None:
                if request.multimodal_params.multimodal_data is not None:
                    # NOTE: Deserialize SharedTensor handle to actual tensor
                    request.multimodal_params.to_tensor("multimodal_data")
                executor_request.py_multimodal_data = request.multimodal_params.multimodal_data

            if self._is_pytorch_backend and request.sampling_params.logits_processor:
                # For PyTorch backend, we attach logits processors as a dynamic Python attribute
                # instead of using the C++ binding, since the latter will cause PyCapsule pickling issues.
                lp = request.sampling_params.logits_processor
                executor_request.py_logits_post_processors = lp if isinstance(
                    lp, list) else [lp]

            executor_request.py_scheduling_params = None
            if self._is_pytorch_backend and request.scheduling_params is not None:
                executor_request.py_scheduling_params = request.scheduling_params

            if request.arrival_time is not None:
                executor_request.py_arrival_time = request.arrival_time

            if request.query_token_ids is not None:
                # pytorch star attention workflow
                # a workaround to avoid a public interface update
                if self._is_pytorch_backend and result_wait_queue is not None:
                    req_id = self.engine.enqueue_request(
                        executor_request,
                        request.query_token_ids,
                        result_wait_queue=result_wait_queue)
                else:
                    req_id = self.engine.enqueue_request(
                        executor_request, request.query_token_ids)
            else:
                if self._is_pytorch_backend and result_wait_queue is not None:
                    req_id = self.engine.enqueue_request(
                        executor_request,
                        result_wait_queue=result_wait_queue)
                else:
                    req_id = self.engine.enqueue_request(executor_request)
            return req_id
        except Exception as e:
            raise RequestError(str(e)) from e

    def submit(self, request: GenerationRequest) -> GenerationResult:
        """
        Low-level API to the executor. Returns a "future" GenerationResult
        which can be awaited.
        """
        self.start()

        if self.rank != 0:
            raise RuntimeError(
                "Only rank 0 can submit requests.\n"
                "To fix this, ensure that the llm.generate(...) method is "
                "guarded with the `if __name__ == '__main__':` block.")

        client_id = request.id if request.id is not None else self._get_next_client_id(
        )
        if request.id is None:
            request.set_id(client_id)

        logprob_params = self._get_logprob_params(request)

        result = GenerationResult(
            request,
            background_error_handler=self._handle_background_error,
            executor=self,
            disaggregated_params=request.disaggregated_params,
            logprob_params=logprob_params)

        self._results[client_id] = result

        request_id = self._enqueue_request(request)
        # The request_id returned from the backend is necessary for the abort_request method.
        self._client_id_to_request_id[client_id] = request_id

        self._handle_background_error()

        return result

    def shutdown(self):
        if self.doing_shutdown:
            return
        else:
            self.doing_shutdown = True

        if self.engine is not None and self.engine.can_enqueue_requests():
            self.engine.shutdown()
        self.engine = None

    # Define a Callable to join iteration and request stats
    @staticmethod
    def _stats_serializer(
            stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str:
        iteration_stats, req_stats = stats
        stats_dict = json.loads(iteration_stats.to_json_str())

        if req_stats is not None and len(req_stats) > 0:
            stats_dict["requestStats"] = []
            for req_stat in req_stats:
                stats_dict["requestStats"].append(
                    json.loads(req_stat.to_json_str()))

        # Convert back to a JSON string
        return json.dumps(stats_dict)

    # Define a Callable to serialize KV cache events
    @staticmethod
    def _kv_cache_events_serializer(events) -> str:
        from .._utils import KVCacheEventSerializer
        return json.dumps(KVCacheEventSerializer.serialize(events))

    def _pop_result(self, client_id: int):
        self._results.pop(client_id, None)
        self._client_id_to_request_id.pop(client_id, None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        self.shutdown()
        return exc_type is None or exc_type == self.WorkerExit

    def __del__(self):
        self.shutdown()


class AwaitResponseHelper:
    ''' Multiple implementations of await_response for performance.
    '''

    class HandlerKind(enum.Enum):
        unknown = 0
        single_process_worker = 1
        ipc_batched = 2

    def __init__(self, worker: "BaseWorker"):
        # TODO: make worker a weakref
        self.worker = worker
        self.handler_kind: AwaitResponseHelper.HandlerKind = AwaitResponseHelper.HandlerKind.unknown
        self.enable_postprocess_parallel = self.worker.enable_postprocess_parallel
        # Error responses produced when request submission fails are put here
        self.temp_error_responses = Queue()

    def responses_handler(self, responses: List[tllm.Response]):
        HandlerKind = AwaitResponseHelper.HandlerKind

        if self.handler_kind is HandlerKind.unknown:
            if not (self.worker.result_queue is not None
                    or self.worker.postproc_queues is not None):
                logger_debug(f"creating await_response helper for Worker\n",
                             color="yellow")
                # When ExecutorBindingWorker is used in the main process,
                # aka the single-process mode
                self.handler_kind = HandlerKind.single_process_worker
            elif self.worker.result_queue is not None or self.worker.postproc_queues is not None:
                # The ExecutorBindingProxy is used
                logger_debug(f"creating await_response helper for IPC\n",
                             color="yellow")
                self.handler_kind = HandlerKind.ipc_batched
            else:
                raise NotImplementedError

        match self.handler_kind:
            case HandlerKind.single_process_worker:
                return self.handle_for_worker(responses)
            case HandlerKind.ipc_batched:
                return self.handle_for_ipc_batched(responses)
            case _:
                raise NotImplementedError

    def __call__(self, timeout: Optional[float] = None) -> bool:
        ''' This method should be called by a ManagedThread. '''
        timeout = timeout or 0.1
        responses = self.worker.engine.await_responses(
            timeout=datetime.timedelta(seconds=timeout))

        # filter, since _engine_response_callback may return None
        responses = list(
            filter(
                lambda _: _,
                [self.worker._engine_response_callback(r) for r in responses]))

        # append the error responses to the temp_error_responses
        while not self.temp_error_responses.empty():
            responses.append(self.temp_error_responses.get())

        with nvtx_range_debug(f"await_response-{len(responses)}",
                              color="red",
                              category="Worker"):
            self.responses_handler(responses)
        return True

    def handle_for_worker(self, responses: List[tllm.Response]) -> None:
        ''' Return the responses to the asyncio event loop. '''
        event_loop = None
        async_queues = []
        for response in responses:
            assert response is not None
            queue = self.worker.return_queue(response.client_id)

            if not response.has_error():
                response = _maybe_wrap_response(
                    self.worker, response, self.worker._is_pytorch_backend)

            # For AsyncQueue.sync_q, we will batch the events to avoid too many
            # event notifications, thus put without wait here.
            if isinstance(queue, _SyncQueue):
                global_tracer().log_instant("worker-rsp.put")
                queue.put_nowait(response)
                async_queues.append(queue)
                # all the loops are identical
                event_loop = event_loop or queue.loop
            else:
                queue.put(response)

            if response.has_error() or response.result.is_final:
                self.worker._pop_result(response.client_id)

        # Notify the events in bulk for performance.
        if async_queues:
            _SyncQueue.notify_many(event_loop, async_queues)

    def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None:
        ''' Perform the IPC in batch explicitly.
        '''
        postproc_batches = [
            []
            for _ in range(self.worker.postproc_config.num_postprocess_workers)
        ] if self.enable_postprocess_parallel else None
        rsp_batch = [] if not self.enable_postprocess_parallel else None

        for response in responses:

            if isinstance(response, ErrorResponse):
                pass  # send ErrorResponse directly
            elif self.worker._has_background_error():
                response = self.worker._create_error_response(response)
            elif response.has_error():
                # Convert to ErrorResponse, because tllm.Response cannot be
                # serialized when it has an error.
                response = ErrorResponse(response.client_id,
                                         response.error_msg,
                                         response.request_id)
            else:
                response = _maybe_wrap_response(
                    self.worker, response, self.worker._is_pytorch_backend)

            _send_rsp(self.worker,
                      response,
                      postproc_batches=postproc_batches,
                      rsp_batch=rsp_batch)

        if postproc_batches:
            for wid, batch in enumerate(postproc_batches):
                if batch:
                    self.worker.postproc_queues[wid].put(batch)

        if rsp_batch:
            self.worker.result_queue.put(rsp_batch)


def _get_params_for_first_rsp(
        worker,
        client_id) -> Tuple[Optional[SamplingParams], Optional[PostprocParams]]:
    res = worker._results.get(client_id, None)
    assert res is not None
    if not res._params_transmitted:
        res._params_transmitted = True
        return res.sampling_params, res.postproc_params
    return None, None


def _compute_pytorch_prompt_logprobs(
        generation_result: GenerationResult,
        response: LlmResponse) -> Optional[LogProbsResult]:
    """Compute prompt logprobs for the PyTorch backend (cached when streaming)."""
    logprob_params = generation_result._logprob_params
    # should be present and non-None
    assert logprob_params is not None
    if generation_result._streaming:
        cached = getattr(generation_result, '_cached_prompt_logprobs', None)
        if cached is not None:
            # Generation logprobs, if requested, are provided directly in
            # response.result.log_probs from the sampler.
            return LogProbsResult(prompt=cached, generation=None)

    context_logits = response.result.context_logits
    assert context_logits is not None, "context_logits cannot be None when prompt_logprobs is requested."
    logprobs_result = compute_logprobs(logprob_params.prompt_logprobs, None,
                                       context_logits, None, None)
    if generation_result._streaming:
        generation_result._cached_prompt_logprobs = logprobs_result.prompt
    return logprobs_result


def _get_logprobs(worker,
                  response: Union[tllm.Response, LlmResponse],
                  is_pytorch_backend=False) -> Optional[LogProbsResult]:
    """Compute logprobs from response logits when needed.
    Logprobs provenance varies by backend:
    - PyTorch: Generation logprobs are computed in the sampler; only prompt logprobs are computed here
    - TRT: Both prompt and generation logprobs are computed here from logits
    """
    logprobs_result = None
    generation_result = worker._results.get(response.client_id, None)
    if not generation_result:
        return None

    logprob_params = getattr(generation_result, "_logprob_params", None)
    if logprob_params:
        if is_pytorch_backend:
            if not logprob_params.prompt_logprobs:
                # PyTorch: generation logprobs are computed in the sampler, no post-processing needed
                return None
            else:
                logprobs_result = _compute_pytorch_prompt_logprobs(
                    generation_result, response)
                if logprob_params.drop_context_logits:
                    response.clear_context_logits()
                return logprobs_result

        # TRT backend: compute both prompt and generation logprobs from logits
        logprobs_result = compute_logprobs(logprob_params.prompt_logprobs,
                                           logprob_params.logprobs,
                                           response.result.context_logits,
                                           response.result.generation_logits,
                                           response.result.output_token_ids[0])

        if logprob_params.drop_context_logits:
            response.clear_context_logits()

        if logprob_params.drop_generation_logits:
            response.clear_generation_logits()

    if response.result.is_final:
        generation_result.clear_logprob_params()

    return logprobs_result


def _send_rsp(
        worker,
        response: Union[tllm.Response, ResponseWrapper, ErrorResponse],
        postproc_batches: Optional[List[List["PostprocWorker.Input"]]] = None,
        rsp_batch: Optional[List[tllm.Response]] = None):
    # if postproc_batches is set, append to the batch instead of putting to the IpcQueue
    if worker.result_queue is not None:
        if rsp_batch is not None:
            rsp_batch.append(response)
        else:
            worker.result_queue.put(response)
    else:
        sampling_params, postproc_params = _get_params_for_first_rsp(
            worker, response.client_id)
        inp = PostprocWorker.Input(
            response,
            # sampling_params is necessary for creating fake GenerationResult
            # instances in the postproc processes. They are for incremental
            # detokenization. They should be transmitted only once for each
            # Request.
            sampling_params=sampling_params,
            postproc_params=postproc_params,
            streaming=worker._results.get(response.client_id,
                                          None)._streaming)

        pid = response.client_id % worker.postproc_config.num_postprocess_workers

        if not postproc_batches:
            # Group the responses into buckets for the postprocessing steps.
            # Bucketing is used instead of random dispatching because the
            # incremental detokenization during postprocessing relies on the
            # prior CompletionOutput of a given request.
            worker.postproc_queues[pid].put(inp)
        else:
            postproc_batches[pid].append(inp)

    # Eliminate finished GenerationRequest instances in a timely manner, as they
    # may take considerable memory.
    if is_llm_response(response):
        if response.has_error() or response.result.is_final:
            worker._pop_result(response.client_id)
    elif isinstance(response, ErrorResponse):
        worker._pop_result(response.client_id)
    else:
        raise ValueError(f"Unknown response type: {response}")


def _get_metrics_dict(
        response: tllm.Response) -> dict[RequestEventTiming, float]:
    req_perf_metrics, metrics_dict = None, {}
    res = response.result
    if res:
        if hasattr(res, '_result'):
            if result := res.get_result():
                req_perf_metrics = result.request_perf_metrics
        else:
            req_perf_metrics = res.request_perf_metrics
        if req_perf_metrics and req_perf_metrics.timing_metrics:
            metrics_dict = {
                RequestEventTiming.ARRIVAL_TIME:
                req_perf_metrics.timing_metrics.arrival_time.total_seconds(),
                RequestEventTiming.FIRST_TOKEN_TIME:
                req_perf_metrics.timing_metrics.first_token_time.total_seconds(
                ),
                RequestEventTiming.FIRST_SCHEDULED_TIME:
                req_perf_metrics.timing_metrics.first_scheduled_time.
                total_seconds(),
                RequestEventTiming.LAST_TOKEN_TIME:
                req_perf_metrics.timing_metrics.last_token_time.total_seconds(),
                RequestEventTiming.KV_CACHE_TRANSFER_START:
                req_perf_metrics.timing_metrics.kv_cache_transfer_start.
                total_seconds(),
                RequestEventTiming.KV_CACHE_TRANSFER_END:
                req_perf_metrics.timing_metrics.kv_cache_transfer_end.
                total_seconds(),
                RequestEventTiming.KV_CACHE_SIZE:
                req_perf_metrics.timing_metrics.kv_cache_size,
            }
    return metrics_dict


def _maybe_wrap_response(
        worker,
        response: tllm.Response,
        is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]:
    logprobs_result = _get_logprobs(worker, response, is_pytorch_backend)
    req_perf_metrics = _get_metrics_dict(response)
    if logprobs_result or req_perf_metrics:
        response = ResponseWrapper(response, logprobs_result, req_perf_metrics)
    return response
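# Illustrative note (comments only, not executed at runtime, and not part of the
# library's API): when the user does not set `max_tokens`, `_deduce_max_tokens`
# in `BaseWorker._enqueue_request` effectively computes
#     default_max_tokens = max_seq_len - int(len(prompt_token_ids) / cp_size) - query_token_len
# e.g. with max_seq_len=4096, cp_size=1, a 1000-token prompt and no query tokens,
# a request submitted without an explicit `max_tokens` is capped at
# 4096 - 1000 - 0 = 3096 generated tokens.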