TensorRT-LLMs/tensorrt_llm/executor/base_worker.py

import copy
import datetime
import enum
import json
import os
import weakref
from pathlib import Path
from queue import Queue
from typing import Dict, List, Optional, Tuple, Union
import psutil
import torch
from tensorrt_llm.logger import logger
from .._torch.pyexecutor.llm_request import LlmResponse
from .._utils import (global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank,
nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import BaseLlmArgs, PybindMirror
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import _SyncQueue, get_numa_aware_cpu_affinity, logger_debug
from ..lora_manager import LoraManager
from ..metrics import RequestEventTiming
from ..prompt_adapter_manager import PromptAdapterManager
from ..runtime import ModelConfig
from ..runtime.model_runner import _engine_config_to_model_config
from ..sampling_params import BatchedLogitsProcessor, SamplingParams
from .executor import GenerationExecutor, IterationResultQueue
from .ipc import FusedIpcQueue, IpcQueue
from .postproc_worker import (PostprocParams, PostprocWorker,
PostprocWorkerConfig)
from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
from .result import (GenerationResult, LogProbsResult, ResponseWrapper,
compute_logprobs)
from .utils import (ErrorResponse, IntraProcessQueue, RequestError,
is_llm_response)
__all__ = [
"BaseWorker",
"_init_hf_modules",
]
def _init_hf_modules():
"""Initialize cached HuggingFace modules for models with trust_remote_code=True.
This is safe to call multiple times (idempotent) and should be called:
1. At module import time (for main process and spawned subprocesses)
2. At worker_main entry (for forked processes or external MPI ranks)
References: https://github.com/vllm-project/vllm/pull/871
"""
try:
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
logger.debug("HF modules initialized")
except ImportError as e:
logger.warning(f"ImportError initializing HF modules: {e}")
except Exception as e:
logger.error(f"Exception initializing HF modules: {e}")
_init_hf_modules()
class BaseWorker(GenerationExecutor):
class WorkerExit(GeneratorExit):
pass
def __init__(
self,
engine: Union[Path, Engine],
executor_config: Optional[tllm.ExecutorConfig] = None,
batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[BaseLlmArgs] = None,
) -> None:
postproc_config = postproc_worker_config or PostprocWorkerConfig()
super().__init__(
num_postprocess_workers=postproc_config.num_postprocess_workers,
postprocess_tokenizer_dir=postproc_config.postprocess_tokenizer_dir,
is_llm_executor=is_llm_executor,
)
# inputs
self._engine = engine
self._executor_config = executor_config
self._batched_logits_processor = batched_logits_processor
self._postproc_worker_config = postproc_worker_config
self._is_llm_executor = is_llm_executor
self._hf_model_dir = hf_model_dir
self._tokenizer = tokenizer
self.llm_args = llm_args
self.engine = None
self.result_queue: Optional[IpcQueue] = None
self.postproc_queues: Optional[List[IpcQueue]] = None
self.rank = mpi_rank()
self.global_rank = global_mpi_rank()
# mapping: client_id -> GenerationResult
self._results: Dict[int, GenerationResult] = {}
# mapping: client_id from Proxy -> request_id returned from runtime backend
self._client_id_to_request_id: Dict[int, int] = {}
self._await_response_helper = AwaitResponseHelper(weakref.proxy(self))
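        # Resolve the backend from llm_args; both "pytorch" and "_autodeploy"
        # use the PyTorch execution path, which affects LoRA handling and how
        # requests are built in _enqueue_request.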
self._backend = None if llm_args is None else llm_args.backend
self._is_pytorch_backend = self._backend in ["pytorch", "_autodeploy"]
self._lora_config = llm_args.lora_config if self._is_pytorch_backend else None
if global_mpi_size() > 1:
logger.set_rank(self.global_rank)
def _configure_affinity(self, device_id):
'''Probe and configure the CPU affinity of the worker based on NUMA topology.
Args:
device_id: The CUDA device ID to determine optimal CPU affinity.
Note:
If the process already has constrained affinity, a warning is logged.
Configuration is handled as follows:
TLLM_NUMA_AWARE_WORKER_AFFINITY = <unset>
-> Affinity is automatically configured if it is unconstrained,
and deleted if it is constrained externally by the user.
TLLM_NUMA_AWARE_WORKER_AFFINITY = 1
-> Affinity is unconditionally auto-configured.
TLLM_NUMA_AWARE_WORKER_AFFINITY = 0 or any other value
-> Affinity is unconditionally _not_ auto-configured.
'''
# Get the current affinity setting
pid = os.getpid()
process = psutil.Process(pid)
cpu_affinity = process.cpu_affinity()
all_cpus = list(range(psutil.cpu_count()))
constrained_affinity = (cpu_affinity != all_cpus)
numa_aware_affinity = os.environ.get("TLLM_NUMA_AWARE_WORKER_AFFINITY")
# If affinity is constrained but the user hasn't explicitly
# requested NUMA-aware affinity, remove the constraints.
if constrained_affinity:
logger.warning(
f"Worker process {pid} is affined to run on the following CPUs: "
f"{cpu_affinity} (subset of all logical CPUs). This may harm "
f"performance if set incorrectly.")
if numa_aware_affinity is None:
logger.warning(
f"Worker process {pid} has constrained CPU affinity "
f"but `TLLM_NUMA_AWARE_WORKER_AFFINITY` is not set. "
f"Removing CPU affinity constraints.")
process.cpu_affinity(all_cpus)
# If affinity is unconstrained and the user hasn't explicitly
# prohibited it or the user has explicitly requested it, choose the
# optimal affinity based upon the NUMA topology
if ((numa_aware_affinity is None and not constrained_affinity)
or (numa_aware_affinity == "1")):
process.cpu_affinity(get_numa_aware_cpu_affinity(device_id))
logger.info(
f"Worker process {pid} CPU affinity set to "
f"{process.cpu_affinity()} for optimal NUMA-aware scheduling.")
def _get_comm_ranks_device_id(self):
device_id = self.global_rank % torch.cuda.device_count()
torch.cuda.set_device(device_id)
        # Make sure the C++ executor uses the same devices/ranks as py_executor
global_rank = global_mpi_rank()
comm_ranks = mpi_comm().allgather(global_rank)
device_ids = mpi_comm().allgather(device_id)
self._configure_affinity(device_id)
return comm_ranks, device_ids
def setup_engine(self):
"""
        Set up the engine for the worker.
"""
if isinstance(self._engine, list):
self._engine = self._engine[self.rank]
def _create_py_executor():
args = {}
assert hasattr(
self.llm_args, "backend"
), "llm_args should be with backend in _create_py_executor"
_ = self._get_comm_ranks_device_id()
if self._backend == "pytorch":
from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
create_py_executor
create_executor = create_py_executor
args["llm_args"] = self.llm_args
args["checkpoint_dir"] = self._hf_model_dir
args["tokenizer"] = self._tokenizer
elif self._backend == "_autodeploy":
from tensorrt_llm._torch.auto_deploy.llm_args import \
LlmArgs as ADLlmArgs
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
create_autodeploy_executor
create_executor = create_autodeploy_executor
assert isinstance(self.llm_args, ADLlmArgs)
args["ad_config"] = self.llm_args
args["tokenizer"] = self._tokenizer
else:
raise ValueError(f"Unsupported backend config: {self._backend}")
# Define additional attributes that can be used later, such as in _deduce_max_tokens
self.mapping = self.llm_args.parallel_config.to_mapping()
self.checkpoint_loader = None
if self._backend == "pytorch":
from tensorrt_llm._torch.pyexecutor.model_loader import \
_construct_checkpoint_loader
self.checkpoint_loader = _construct_checkpoint_loader(
self.llm_args.backend, self.llm_args.checkpoint_loader,
self.llm_args.checkpoint_format)
self.max_seq_len = self.llm_args.max_seq_len
            # create_py_executor may change some fields of llm_args
_executor = create_executor(**args)
if _executor.max_seq_len is not None:
# max_seq_len might be updated by model engine as in create_py_executor
self.max_seq_len = _executor.max_seq_len
return _executor
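        # Legacy TRT flow: wraps a prebuilt engine (in-memory Engine object or
        # engine directory) in a C++ tllm.Executor when no llm_args are given.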
def _create_engine(executor_config):
engine = self._engine
if executor_config is None:
executor_config = tllm.ExecutorConfig(1)
executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
processor_batched=self._batched_logits_processor,
replicate=False)
comm_ranks, device_ids = self._get_comm_ranks_device_id()
executor_config.parallel_config = tllm.ParallelConfig(
participant_ids=comm_ranks, device_ids=device_ids)
if isinstance(engine, Engine):
return tllm.Executor(engine.engine,
json.dumps(engine.config.to_dict(),
cls=ConfigEncoder),
tllm.ModelType.DECODER_ONLY,
executor_config=executor_config,
managed_weights=engine.managed_weights)
assert not hasattr(executor_config, "backend")
return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
executor_config)
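        # Choose the executor implementation: the PyTorch/AutoDeploy path when
        # llm_args is provided, otherwise the TRT engine path above.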
self.engine = _create_py_executor(
) if self.llm_args is not None else _create_engine(
self._executor_config)
self._lora_manager: Optional[LoraManager] = None
self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
self._runtime_model_config: Optional[ModelConfig] = None
if self.rank == 0 and isinstance(self.engine, tllm.Executor):
if isinstance(self.engine, Engine):
engine_config = self.engine.config
else:
engine_config = EngineConfig.from_json_file(
f"{self._engine}/config.json")
self._runtime_model_config = _engine_config_to_model_config(
engine_config)
if engine_config.build_config.plugin_config.lora_plugin:
# TODO(azuker): Passing peft cache manager to LoraManager is used for LoRA optimization
# (see LoraManager constructor docstring). Getting the peft cache manager from this
# point in the TRT flow is currently not supported (it's at the CPP
# Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA
# optimization is not available in TRT-python flow.
self._lora_manager = LoraManager(
mapping=engine_config.pretrained_config.mapping,
model_config=self._runtime_model_config,
cpp_peft_cache_manager=None)
if engine_config.build_config.max_prompt_embedding_table_size > 0:
self._prompt_adapter_manager = PromptAdapterManager()
if self._backend == "pytorch" and self._lora_config is not None:
from tensorrt_llm._torch.pyexecutor.resource_manager import \
ResourceManagerType
peft_cache_manager = self.engine.resource_manager.resource_managers.get(
ResourceManagerType.PEFT_CACHE_MANAGER)
self._lora_manager = peft_cache_manager.get_lora_manager()
lora_model_config = self.engine.model_engine.lora_model_config
assert lora_model_config is not None
self._lora_model_config = lora_model_config
def await_responses(self, timeout: Optional[float] = None) -> list:
return self.engine.await_responses(timeout=datetime.timedelta(
seconds=timeout) if timeout is not None else None)
def fetch_stats(self) -> list:
if isinstance(self.engine, tllm.Executor):
iter_stats = self.engine.get_latest_iteration_stats()
            # TODO: Support request stats with the TRT engine.
            # This would require ensuring iter and req stats have the same size.
return [(iter_stat, None) for iter_stat in iter_stats]
else:
return self.engine.get_latest_iteration_stats()
    def fetch_kv_cache_events(self) -> list:
        # Both engine implementations expose the same call for KV cache events.
        return self.engine.get_latest_kv_cache_events()
def set_result_queue(self, queue):
"""In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process."""
assert self.postproc_queues is None
self.result_queue = queue
def set_postproc_queues(self, queues: List["IpcQueue"]):
""" Set the IPC queues for feeding post-processing processes. """
assert self.result_queue is None
self.postproc_queues = queues
def _set_iteration_result_queue(self, it_result_queue: IterationResultQueue,
queue: Union[Queue, FusedIpcQueue,
IntraProcessQueue]):
assert not it_result_queue.is_initialized, "Iteration result queue should not already be initialized."
it_result_queue.is_initialized = True
it_result_queue.queue = queue
it_result_queue.aqueue = None
def return_queue(self, client_id: int):
""" If a centralized result queue is registered (used for communication with the proxy)
send the message there.
Otherwise, push the result directly in the GenerationResult queue.
"""
if self.result_queue is not None:
return self.result_queue
return self._results[client_id].queue
def abort_request(self, client_id: int) -> None:
        # NOTE: the request_id is generated by the C++ runtime and is distinct from the client_id
if self.engine.can_enqueue_requests():
request_id = self._client_id_to_request_id.get(client_id, None)
if request_id is None:
logger.warning(
f"Request of client_id {client_id} is finished, cannot abort it."
)
return
self.engine.cancel_request(request_id)
def _engine_response_callback(self, response: tllm.Response):
return response
def _has_background_error(self) -> bool:
return not self._error_queue.empty()
def _create_error_response(self, response: tllm.Response) -> ErrorResponse:
bck_error = self._error_queue.get_nowait()
assert isinstance(bck_error, Exception)
return ErrorResponse(response.client_id, str(bck_error),
response.request_id)
def start(self):
raise NotImplementedError(
"start method is not implemented in BaseWorker")
def _load_lora_adapter(self, lora_request: LoRARequest) -> bool:
"""Returns True if the adapter was loaded by this call, False if it was already loaded"""
adapter_id = str(lora_request.adapter_id)
newly_loaded_uids = self._lora_manager.load_from_ckpt(
[lora_request.path],
model_config=self._runtime_model_config if
self._runtime_model_config is not None else self._lora_model_config,
uids=[adapter_id],
ckpt_source=lora_request.ckpt_source)
return adapter_id in newly_loaded_uids
def _load_prompt_adapter(self,
prompt_adapter_request: PromptAdapterRequest):
self._prompt_adapter_manager.load_from_ckpt(
[prompt_adapter_request.local_path],
model_config=self._runtime_model_config,
uids=[str(prompt_adapter_request.adapter_id)])
def _enqueue_request(self,
request: GenerationRequest,
result_wait_queue=None) -> int:
assert request.id is not None
py_lora_path = None
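        # LoRA: load the adapter if needed and build the C++ LoraConfig; the
        # weights are attached only when the adapter is not already in the CPU cache.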
if self._lora_manager is not None and request.lora_request is not None:
adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache(
request.lora_request.adapter_id)
self._load_lora_adapter(request.lora_request)
uid = str(request.lora_request.adapter_id)
lora_config = tllm.LoraConfig(
task_id=request.lora_request.adapter_id,
weights=self._lora_manager.cpp_lora_weights[uid]
if not adapter_in_cache else None,
config=self._lora_manager.cpp_lora_config[uid])
py_lora_path = request.lora_request.lora_path
else:
lora_config = None
prompt_token_ids = copy.deepcopy(request.prompt_token_ids)
prompt_tuning_config = None
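        # Prompt adapter (prompt tuning): virtual token ids (>= vocab_size) are
        # prepended so the runtime resolves them against the prompt-tuning
        # embedding table.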
if request.prompt_adapter_request is not None:
self._load_prompt_adapter(request.prompt_adapter_request)
uid = str(request.prompt_adapter_request.adapter_id)
prompt_tuning_config = tllm.PromptTuningConfig(
self._prompt_adapter_manager.uid_to_weights[uid])
vocab_size = self._runtime_model_config.vocab_size
pa_length = prompt_tuning_config.embedding_table.size(0)
prompt_token_ids = list(range(
vocab_size, vocab_size + pa_length)) + prompt_token_ids
# MULTIMODAL
        # NOTE: Since only the PyTorch backend supports multimodal, multimodal_data is sent
        # through the 'py_multimodal_data' field, except for `multimodal_input`, which must
        # go through the C++ runtime.
multimodal_input = None
if request.multimodal_params is not None and request.multimodal_params.has_content(
):
if request.multimodal_params.multimodal_input is not None:
multimodal_input = tllm.MultimodalInput(
multimodal_hashes=request.multimodal_params.
multimodal_input.multimodal_hashes,
multimodal_positions=request.multimodal_params.
multimodal_input.multimodal_positions,
multimodal_lengths=request.multimodal_params.
multimodal_input.multimodal_lengths)
# NOTE: Setting to None here to avoid sending multimodal_input again through the 'py_multimodal_data' field
request.multimodal_params.multimodal_input = None
context_phase_params = None
request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
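        # Disaggregated serving: derive the request type (context-only,
        # generation-only, or context-and-generation) and, for generation-only
        # requests, carry over the context-phase parameters.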
if request.disaggregated_params is not None:
assert (
not self._is_pytorch_backend
or self.engine.kv_cache_transceiver is not None
or request.disaggregated_params.request_type
== "context_and_generation"
), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend:<backend_type>` in config file for disaggregated serving"
request_type = request.disaggregated_params.get_request_type()
if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY:
context_phase_params = request.disaggregated_params.get_context_phase_params(
)
if self._is_pytorch_backend:
if not self.llm_args.disable_overlap_scheduler:
is_disaggregated = self.engine.kv_cache_transceiver is not None
if is_disaggregated and (
request_type
== tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
raise ValueError(
"Context only requests are not supported in pytorch backend when overlap is enabled."
)
assert request.id is not None
def _deduce_max_tokens(request: GenerationRequest,
executor_config: tllm.ExecutorConfig,
llm_args: Optional[BaseLlmArgs] = None) -> int:
# deduce max_tokens when it's not set by user
max_tokens = request.sampling_params.max_tokens
query_token_len = len(
request.query_token_ids) if request.query_token_ids else 0
cp_size = 1
max_seq_len = None
if llm_args is not None:
# deduce max_tokens by llm args
assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
if hasattr(self,
"mapping") and self.mapping.cp_size is not None:
cp_size = self.mapping.cp_size
max_seq_len = getattr(self, "max_seq_len", None)
else:
# deduce max_tokens by executor config
if hasattr(executor_config, "mapping"
) and executor_config.mapping.cp_size is not None:
cp_size = executor_config.mapping.cp_size
max_seq_len = getattr(executor_config, "max_seq_len", None)
if max_seq_len is None:
logger.warning("`default_max_tokens` cannot be deduced")
if max_tokens is None:
raise ValueError(
"`max_tokens` must be set when `default_max_tokens` cannot be deduced"
)
else:
                    # fall back to max_tokens since default_max_tokens cannot be deduced
return max_tokens
if executor_config is not None:
assert (
len(prompt_token_ids) <= executor_config.max_seq_len
), f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than `max_seq_len` ({executor_config.max_seq_len})"
splited_prompt_len = int(len(prompt_token_ids) / cp_size)
default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
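            # Example with hypothetical numbers: max_seq_len=4096, cp_size=1, a
            # 1000-token prompt and no query tokens give
            # default_max_tokens = 4096 - 1000 - 0 = 3096.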
if default_max_tokens <= 0:
                # Raise an error when `default_max_tokens` is not positive, since `max_tokens` must not exceed `default_max_tokens`.
raise ValueError(
f"`default_max_tokens` ({default_max_tokens}) must be greater than 0, "
f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})"
f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
)
            # default_max_tokens is the largest value that still fits within max_seq_len
if max_tokens is None:
return default_max_tokens
elif max_tokens > default_max_tokens and default_max_tokens > 0:
logger.warning(
f"User-specified `max_tokens` ({max_tokens}) is greater than deduced "
f"`default_max_tokens` ({default_max_tokens}), using default_max_tokens instead."
)
return default_max_tokens
elif max_tokens <= 0:
raise ValueError(
f"`max_tokens` ({max_tokens}) must be greater than 0")
else:
return max_tokens
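        # Build the pybind tllm.Request and enqueue it; any backend failure is
        # re-raised as RequestError for the caller to handle.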
try:
executor_request = tllm.Request(
client_id=request.id,
input_token_ids=prompt_token_ids,
max_tokens=_deduce_max_tokens(
request,
self._executor_config if not self.llm_args else None,
self.llm_args),
streaming=request.streaming,
sampling_config=request.sampling_params._get_sampling_config(),
end_id=-1 if request.sampling_params.ignore_eos else
request.sampling_params.end_id,
pad_id=request.sampling_params.pad_id,
output_config=request.sampling_params._get_output_config(
is_pytorch_backend=self._is_pytorch_backend),
# Beam search enforces return_all_generated_tokens=True regardless of the passed value
return_all_generated_tokens=False,
# convert python config into pybind config
lookahead_config=PybindMirror.maybe_to_pybind(
request.sampling_params.lookahead_config),
guided_decoding_params=request.sampling_params.
_get_guided_decoding_params(),
bad_words=request.sampling_params._get_bad_words(),
stop_words=[] if request.sampling_params.ignore_eos else
request.sampling_params._get_stop_words(),
embedding_bias=request.sampling_params.embedding_bias,
lora_config=lora_config,
prompt_tuning_config=prompt_tuning_config,
multimodal_input=multimodal_input,
# NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`.
multimodal_embedding=None,
mrope_config=None,
logits_post_processor_name=(
tllm.Request.BATCHED_POST_PROCESSOR_NAME
if request.sampling_params.apply_batched_logits_processor
else None),
logits_post_processor=None if self._is_pytorch_backend else
request.sampling_params.logits_processor,
kv_cache_retention_config=request.kv_cache_retention_config,
context_phase_params=context_phase_params,
type=request_type,
cache_salt_id=request.cache_salt_id)
executor_request.py_num_logprobs = request.sampling_params.logprobs
executor_request.py_lora_path = py_lora_path
if self._is_pytorch_backend and request.multimodal_params is not None:
if request.multimodal_params.multimodal_data is not None:
# NOTE: Deserialize SharedTensor handle to actual tensor
request.multimodal_params.to_tensor("multimodal_data")
executor_request.py_multimodal_data = request.multimodal_params.multimodal_data
if self._is_pytorch_backend and request.sampling_params.logits_processor:
# For PyTorch backend, we attach logits processors as a dynamic Python attribute
# instead of using the C++ binding, since the latter will cause PyCapsule pickling issues.
lp = request.sampling_params.logits_processor
executor_request.py_logits_post_processors = lp if isinstance(
lp, list) else [lp]
executor_request.py_scheduling_params = None
if self._is_pytorch_backend and request.scheduling_params is not None:
executor_request.py_scheduling_params = request.scheduling_params
if request.arrival_time is not None:
executor_request.py_arrival_time = request.arrival_time
if request.query_token_ids is not None:
            # PyTorch star-attention workflow; passing query_token_ids separately
            # is a workaround to avoid a public interface update.
if self._is_pytorch_backend and result_wait_queue is not None:
req_id = self.engine.enqueue_request(
executor_request,
request.query_token_ids,
result_wait_queue=result_wait_queue)
else:
req_id = self.engine.enqueue_request(
executor_request, request.query_token_ids)
else:
if self._is_pytorch_backend and result_wait_queue is not None:
req_id = self.engine.enqueue_request(
executor_request, result_wait_queue=result_wait_queue)
else:
req_id = self.engine.enqueue_request(executor_request)
return req_id
except Exception as e:
raise RequestError(str(e)) from e
def submit(self, request: GenerationRequest) -> GenerationResult:
""" Low-level API to the executor. Return a "future" GenerationResult which can be waited. """
self.start()
if self.rank != 0:
raise RuntimeError(
"Only rank 0 can submit requests.\n"
"To fix this, ensure that the llm.generate(...) method is "
"guarded with the `if __name__ == '__main__':` block.")
client_id = request.id if request.id is not None else self._get_next_client_id(
)
if request.id is None:
request.set_id(client_id)
logprob_params = self._get_logprob_params(request)
result = GenerationResult(
request,
background_error_handler=self._handle_background_error,
executor=self,
disaggregated_params=request.disaggregated_params,
logprob_params=logprob_params)
self._results[client_id] = result
request_id = self._enqueue_request(request)
# request_id returned from backend is necessary for the abort_request method.
self._client_id_to_request_id[client_id] = request_id
self._handle_background_error()
return result
def shutdown(self):
if self.doing_shutdown:
return
else:
self.doing_shutdown = True
if self.engine is not None and self.engine.can_enqueue_requests():
self.engine.shutdown()
self.engine = None
# Define a Callable to join iteration and request stats
@staticmethod
def _stats_serializer(
stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str:
iteration_stats, req_stats = stats
stats_dict = json.loads(iteration_stats.to_json_str())
if req_stats is not None and len(req_stats) > 0:
stats_dict["requestStats"] = []
for req_stat in req_stats:
stats_dict["requestStats"].append(
json.loads(req_stat.to_json_str()))
# Convert back to JSON string
return json.dumps(stats_dict)
# Define a Callable to serialize KV cache events
@staticmethod
def _kv_cache_events_serializer(events) -> str:
from .._utils import KVCacheEventSerializer
return json.dumps(KVCacheEventSerializer.serialize(events))
def _pop_result(self, client_id: int):
self._results.pop(client_id, None)
self._client_id_to_request_id.pop(client_id, None)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback) -> bool:
self.shutdown()
return exc_type is None or exc_type == self.WorkerExit
def __del__(self):
self.shutdown()
class AwaitResponseHelper:
    ''' Multiple implementations of await_response, selected for performance. '''
class HandlerKind(enum.Enum):
unknown = 0
single_process_worker = 1
ipc_batched = 2
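        # single_process_worker: responses are pushed straight into each
        # request's GenerationResult queue (the worker lives in the caller's process).
        # ipc_batched: responses are batched and shipped over IPC to the proxy's
        # result queue or to the post-processing workers.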
def __init__(self, worker: "BaseWorker"):
# TODO: make worker weakref
self.worker = worker
self.handler_kind: AwaitResponseHelper.HandlerKind = AwaitResponseHelper.HandlerKind.unknown
        self.enable_postprocess_parallel = self.worker.enable_postprocess_parallel
        # Error responses from failed request submissions are put here
self.temp_error_responses = Queue()
def responses_handler(self, responses: List[tllm.Response]):
HandlerKind = AwaitResponseHelper.HandlerKind
if self.handler_kind is HandlerKind.unknown:
if not (self.worker.result_queue is not None
or self.worker.postproc_queues is not None):
logger_debug(f"creating await_response helper for Worker\n",
color="yellow")
                # When the worker is used directly in the main process,
                # i.e. the single-process mode.
self.handler_kind = HandlerKind.single_process_worker
elif self.worker.result_queue is not None or self.worker.postproc_queues is not None:
# The ExecutorBindingProxy is used
logger_debug(f"creating await_response helper for IPC\n",
color="yellow")
self.handler_kind = HandlerKind.ipc_batched
else:
raise NotImplementedError
match self.handler_kind:
case HandlerKind.single_process_worker:
return self.handle_for_worker(responses)
case HandlerKind.ipc_batched:
return self.handle_for_ipc_batched(responses)
case _:
raise NotImplementedError
def __call__(self, timeout: Optional[float] = None) -> bool:
''' This method should be called by a ManagedThread. '''
timeout = timeout or 0.1
responses = self.worker.engine.await_responses(
timeout=datetime.timedelta(seconds=timeout))
        # filter, since _engine_response_callback may return None
responses = list(
filter(
lambda _: _,
[self.worker._engine_response_callback(r) for r in responses]))
# append the error responses to the temp_error_responses
while not self.temp_error_responses.empty():
responses.append(self.temp_error_responses.get())
with nvtx_range_debug(f"await_response-{len(responses)}",
color="red",
category="Worker"):
self.responses_handler(responses)
return True
def handle_for_worker(self, responses: List[tllm.Response]) -> None:
''' Return the responses to asyncio.event_loop. '''
event_loop = None
async_queues = []
for response in responses:
assert response is not None
queue = self.worker.return_queue(response.client_id)
if not response.has_error():
response = _maybe_wrap_response(self.worker, response,
self.worker._is_pytorch_backend)
# For AsyncQueue.sync_q, we will batch the events to avoid too many
# event notifications, thus put without wait here.
if isinstance(queue, _SyncQueue):
global_tracer().log_instant("worker-rsp.put")
queue.put_nowait(response)
async_queues.append(queue)
# all the loops are identical
event_loop = event_loop or queue.loop
else:
queue.put(response)
if response.has_error() or response.result.is_final:
self.worker._pop_result(response.client_id)
# Notify the events in bulk for performance.
if async_queues:
_SyncQueue.notify_many(event_loop, async_queues)
def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None:
''' Perform the IPC in batch explicitly. '''
postproc_batches = [
[]
for _ in range(self.worker.postproc_config.num_postprocess_workers)
        ] if self.enable_postprocess_parallel else None
        rsp_batch = [] if not self.enable_postprocess_parallel else None
for response in responses:
if isinstance(response, ErrorResponse):
pass # send ErrorResponse directly
elif self.worker._has_background_error():
response = self.worker._create_error_response(response)
elif response.has_error():
# Convert to ErrorResponse, because tllm.Response cannot be
# serialized when it has error.
response = ErrorResponse(response.client_id, response.error_msg,
response.request_id)
else:
response = _maybe_wrap_response(self.worker, response,
self.worker._is_pytorch_backend)
_send_rsp(self.worker,
response,
postproc_batches=postproc_batches,
rsp_batch=rsp_batch)
if postproc_batches:
for wid, batch in enumerate(postproc_batches):
if batch:
self.worker.postproc_queues[wid].put(batch)
if rsp_batch:
self.worker.result_queue.put(rsp_batch)
def _get_params_for_first_rsp(
worker,
client_id) -> Tuple[Optional[SamplingParams], Optional[PostprocParams]]:
res = worker._results.get(client_id, None)
assert res is not None
if not res._params_transmitted:
res._params_transmitted = True
return res.sampling_params, res.postproc_params
return None, None
def _compute_pytorch_prompt_logprobs(
generation_result: GenerationResult,
response: LlmResponse) -> Optional[LogProbsResult]:
"""Compute prompt logprobs for PyTorch backend (cached when streaming) """
logprob_params = generation_result._logprob_params # should be present and non None
assert logprob_params is not None
if generation_result._streaming:
cached = getattr(generation_result, '_cached_prompt_logprobs', None)
if cached is not None:
            # Generation logprobs, if requested, are provided directly in
            # response.result.log_probs by the sampler.
            return LogProbsResult(prompt=cached, generation=None)
context_logits = response.result.context_logits
assert context_logits is not None, "context_logits cannot be None when prompt_logprobs is requested."
logprobs_result = compute_logprobs(logprob_params.prompt_logprobs, None,
context_logits, None, None)
if generation_result._streaming:
generation_result._cached_prompt_logprobs = logprobs_result.prompt
return logprobs_result
def _get_logprobs(worker,
response: Union[tllm.Response, LlmResponse],
is_pytorch_backend=False) -> Optional[LogProbsResult]:
"""Compute logprobs from response logits when needed.
Logprobs provenance varies by backend:
- PyTorch: Generation logprobs computed in sampler, only prompt logprobs computed here
- TRT: Both prompt and generation logprobs computed here from logits
"""
logprobs_result = None
generation_result = worker._results.get(response.client_id, None)
if not generation_result:
return None
logprob_params = getattr(generation_result, "_logprob_params", None)
if logprob_params:
if is_pytorch_backend:
if not logprob_params.prompt_logprobs:
# PyTorch: generation logprobs computed in sampler, no post-processing needed
return None
else:
logprobs_result = _compute_pytorch_prompt_logprobs(
generation_result, response)
if logprob_params.drop_context_logits:
response.clear_context_logits()
return logprobs_result
# TRT backend: compute both prompt and generation logprobs from logits
logprobs_result = compute_logprobs(logprob_params.prompt_logprobs,
logprob_params.logprobs,
response.result.context_logits,
response.result.generation_logits,
response.result.output_token_ids[0])
if logprob_params.drop_context_logits:
response.clear_context_logits()
if logprob_params.drop_generation_logits:
response.clear_generation_logits()
if response.result.is_final:
generation_result.clear_logprob_params()
return logprobs_result
def _send_rsp(
worker,
response: Union[tllm.Response, ResponseWrapper, ErrorResponse],
postproc_batches: Optional[List[List["PostprocWorker.Input"]]] = None,
rsp_batch: Optional[List[tllm.Response]] = None):
    # If postproc_batches is set, append to the batch instead of putting it into the IpcQueue.
if worker.result_queue is not None:
if rsp_batch is not None:
rsp_batch.append(response)
else:
worker.result_queue.put(response)
else:
sampling_params, postproc_params = _get_params_for_first_rsp(
worker, response.client_id)
inp = PostprocWorker.Input(
response,
            # sampling_params is necessary for creating fake GenerationResult
            # instances in the postproc processes. They are used for incremental
            # detokenization and should be transmitted only once per request.
sampling_params=sampling_params,
postproc_params=postproc_params,
streaming=worker._results.get(response.client_id, None)._streaming)
pid = response.client_id % worker.postproc_config.num_postprocess_workers
if not postproc_batches:
# Group the responses into buckets for the postprocessing steps.
# Bucketing is used instead of random dispatching because the
# incremental detokenization during postprocessing relies on the
# prior CompletionOutput of a given request.
worker.postproc_queues[pid].put(inp)
else:
postproc_batches[pid].append(inp)
    # Promptly drop the bookkeeping for finished requests, which may otherwise
    # hold considerable memory.
if is_llm_response(response):
if response.has_error() or response.result.is_final:
worker._pop_result(response.client_id)
elif isinstance(response, ErrorResponse):
worker._pop_result(response.client_id)
else:
raise ValueError(f"Unknown response type: {response}")
def _get_metrics_dict(
response: tllm.Response) -> dict[RequestEventTiming, float]:
req_perf_metrics, metrics_dict = None, {}
res = response.result
if res:
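        # Results from the PyTorch path wrap their payload behind get_result()
        # (hence the `_result` attribute check); the C++ result exposes
        # request_perf_metrics directly.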
if hasattr(res, '_result'):
if result := res.get_result():
req_perf_metrics = result.request_perf_metrics
else:
req_perf_metrics = res.request_perf_metrics
if req_perf_metrics and req_perf_metrics.timing_metrics:
metrics_dict = {
RequestEventTiming.ARRIVAL_TIME:
req_perf_metrics.timing_metrics.arrival_time.total_seconds(),
RequestEventTiming.FIRST_TOKEN_TIME:
req_perf_metrics.timing_metrics.first_token_time.total_seconds(
),
RequestEventTiming.FIRST_SCHEDULED_TIME:
req_perf_metrics.timing_metrics.first_scheduled_time.
total_seconds(),
RequestEventTiming.LAST_TOKEN_TIME:
req_perf_metrics.timing_metrics.last_token_time.total_seconds(),
RequestEventTiming.KV_CACHE_TRANSFER_START:
req_perf_metrics.timing_metrics.kv_cache_transfer_start.
total_seconds(),
RequestEventTiming.KV_CACHE_TRANSFER_END:
req_perf_metrics.timing_metrics.kv_cache_transfer_end.
total_seconds(),
RequestEventTiming.KV_CACHE_SIZE:
req_perf_metrics.timing_metrics.kv_cache_size,
}
return metrics_dict
def _maybe_wrap_response(
worker,
response: tllm.Response,
is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]:
logprobs_result = _get_logprobs(worker, response, is_pytorch_backend)
req_perf_metrics = _get_metrics_dict(response)
if logprobs_result or req_perf_metrics:
response = ResponseWrapper(response, logprobs_result, req_perf_metrics)
return response