TensorRT-LLMs/tensorrt_llm/executor/worker.py
Yibin Li 351808efeb
fix: Use hmac authentication for pickle encryption (#3384)
* hmac initial implementation to encrypt worker and proxy queue

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>

* set different hmac key for each pair of server/client queue

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>

* fix comments

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>

* fix style

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>

---------

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
2025-04-17 00:40:13 +08:00

880 lines
38 KiB
Python

import copy
import datetime
import enum
import json
import os
import time
import traceback
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from queue import Queue
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from tensorrt_llm.logger import logger
from .._utils import (KVCacheEventSerializer, global_mpi_rank, mpi_comm,
mpi_rank, nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import PybindMirror
from ..llmapi.mpi_session import set_mpi_session_cpp
from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
clear_sched_affinity, print_colored_debug,
print_traceback_on_error)
from ..lora_manager import LoraManager
from ..prompt_adapter_manager import PromptAdapterManager
from ..runtime import ModelConfig
from ..runtime.model_runner import _engine_config_to_model_config
from ..sampling_params import BatchedLogitsProcessor, SamplingParams
from .executor import GenerationExecutor, IterationResultQueue
from .ipc import FusedIpcQueue, IpcQueue
from .postproc_worker import (PostprocParams, PostprocWorker,
PostprocWorkerConfig, postproc_worker_main)
from .request import (CancellingRequest, GenerationRequest, LoRARequest,
PromptAdapterRequest)
from .result import GenerationResult, IterationResult
from .utils import (PERIODICAL_RESP_IN_AWAIT, ErrorResponse, IntraProcessQueue,
RequestError, WorkerCommIpcAddrs, has_event_loop)
# Names exported by `from <this module> import *`. Only the worker class is
# listed; the other top-level definitions in this file are not part of it.
__all__ = [
"ExecutorBindingsWorker",
]
class ExecutorBindingsWorker(GenerationExecutor):
    """Executor that owns the runtime engine and feeds requests to it.

    The worker creates the backend engine (the C++ ``tllm.Executor`` or a
    Python executor selected by ``executor_config.backend``), enqueues
    requests, and runs background threads that forward responses, iteration
    statistics and KV-cache events to their queues.

    It is a context manager: leaving the ``with`` block shuts the worker down
    and suppresses ``WorkerExit``.
    """

    class WorkerExit(GeneratorExit):
        """Raised on non-leader ranks to leave the ``with`` block early."""

    def __init__(
        self,
        engine: Union[Path, Engine],
        executor_config: Optional[tllm.ExecutorConfig] = None,
        batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
        postproc_worker_config: Optional[PostprocWorkerConfig] = None,
        is_llm_executor: Optional[bool] = None,
    ) -> None:
        """Create the backend engine and the (not yet started) worker threads.

        Args:
            engine: Engine directory or in-memory ``Engine``. When a list is
                given, the entry at this process' MPI rank is used.
            executor_config: Backend configuration; defaults to
                ``tllm.ExecutorConfig(1)``.
            batched_logits_processor: Installed into
                ``executor_config.logits_post_processor_config``.
            postproc_worker_config: Post-processing settings forwarded to the
                base class; defaults to ``PostprocWorkerConfig()``.
            is_llm_executor: Forwarded to the base class.

        Raises:
            ValueError: If ``executor_config.backend`` is set to an
                unsupported value.
        """
        postproc_config = postproc_worker_config or PostprocWorkerConfig()
        super().__init__(
            num_postprocess_workers=postproc_config.num_postprocess_workers,
            postprocess_tokenizer_dir=postproc_config.postprocess_tokenizer_dir,
            is_llm_executor=is_llm_executor,
        )

        self.engine = None
        self.result_queue: Optional[IpcQueue] = None
        self.postproc_queues: Optional[List[IpcQueue]] = None
        self.rank = mpi_rank()
        self.global_rank = global_mpi_rank()
        # mapping: client_id -> GenerationResult
        self._results: Dict[int, GenerationResult] = {}
        # mapping: client_id from Proxy -> request_id returned from runtime backend
        self._client_id_to_request_id: Dict[int, int] = {}
        self._await_response_helper = AwaitResponseHelper(
            self)  # TODO: make it weakref
        # NOTE: stores the caller's value as passed in (possibly None), i.e.
        # before the default config below is substituted.
        self._executor_config = executor_config

        if isinstance(engine, list):
            engine = engine[self.rank]

        if executor_config is None:
            executor_config = tllm.ExecutorConfig(1)

        executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
            processor_batched=batched_logits_processor, replicate=False)

        def _create_engine():
            """Instantiate the backend selected by engine/executor_config."""
            if isinstance(engine, Engine):
                return tllm.Executor(engine.engine,
                                     json.dumps(engine.config.to_dict(),
                                                cls=ConfigEncoder),
                                     tllm.ModelType.DECODER_ONLY,
                                     executor_config=executor_config,
                                     managed_weights=engine.managed_weights)

            # A config without a ``backend`` attribute selects the C++ executor.
            if not hasattr(executor_config, "backend"):
                return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
                                     executor_config)
            elif executor_config.backend == "pytorch":
                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                    create_py_executor
                create_executor = create_py_executor
            elif executor_config.backend == "autodeploy":
                from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                    create_autodeploy_executor
                create_executor = create_autodeploy_executor
            else:
                raise ValueError(
                    f"Unsupported backend config: {executor_config.backend}")

            # Map the global rank onto one of the locally visible CUDA devices.
            device_id = self.global_rank % torch.cuda.device_count()
            torch.cuda.set_device(device_id)
            return create_executor(executor_config=executor_config,
                                   checkpoint_dir=executor_config.hf_model_dir,
                                   engine_dir=executor_config.trt_engine_dir)

        self.engine = _create_engine()

        self._lora_manager: Optional[LoraManager] = None
        self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
        self._runtime_model_config: Optional[ModelConfig] = None
        # Adapter managers are only set up on rank 0 with the C++ executor.
        if self.rank == 0 and isinstance(self.engine, tllm.Executor):
            if isinstance(engine, Engine):
                engine_config = engine.config
            else:
                engine_config = EngineConfig.from_json_file(
                    f"{engine}/config.json")
            self._runtime_model_config = _engine_config_to_model_config(
                engine_config)
            build_config = engine_config.build_config
            if build_config.plugin_config.lora_plugin:
                self._lora_manager = LoraManager()
            if build_config.max_prompt_embedding_table_size > 0:
                self._prompt_adapter_manager = PromptAdapterManager()

        self.await_response_thread = ManagedThread(
            self.await_response_task,
            error_queue=self._error_queue,
            name="await_response_thread")
        self.dispatch_stats_thread = ManagedThread(
            self.dispatch_stats_task,
            error_queue=self._error_queue,
            name="dispatch_stats_thread")
        self.dispatch_kv_cache_events_thread = ManagedThread(
            self.dispatch_kv_cache_events_task,
            error_queue=self._error_queue,
            name="dispatch_kv_cache_events_thread")

    def set_result_queue(self, queue):
        """In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process."""
        # Mutually exclusive with the post-processing queues.
        assert self.postproc_queues is None
        self.result_queue = queue

    def set_postproc_queues(self, queues: List["IpcQueue"]):
        """ Set the IPC queues for feeding post-processing processes. """
        # Mutually exclusive with the centralized result queue.
        assert self.result_queue is None
        self.postproc_queues = queues

    def _create_iteration_result_queue(self,
                                       it_result_queue: IterationResultQueue):
        """Lazily create the local queue behind ``it_result_queue``.

        Does nothing when the queue is already initialized. An ``AsyncQueue``
        is used when an event loop is present, a plain ``Queue`` otherwise.
        """
        if it_result_queue.is_initialized:
            return

        it_result_queue.is_initialized = True
        if has_event_loop():
            _queue = AsyncQueue()
            it_result_queue.queue = _queue.sync_q
            it_result_queue.aqueue = _queue
        else:
            _queue = Queue()
            it_result_queue.queue = _queue
            it_result_queue.aqueue = None

    def _set_iteration_result_queue(self, it_result_queue: IterationResultQueue,
                                    queue: Union[Queue, FusedIpcQueue,
                                                 IntraProcessQueue]):
        """Install an externally created ``queue`` into ``it_result_queue``.

        The target must not have been initialized before.
        """
        assert not it_result_queue.is_initialized, "Iteration result queue should not already be initialized."
        it_result_queue.is_initialized = True
        it_result_queue.queue = queue
        it_result_queue.aqueue = None

    def return_queue(self, client_id: int):
        """ If a centralized result queue is registered (used for communication with the proxy)
        send the message there.
        Otherwise, push the result directly in the GenerationResult queue.
        """
        if self.result_queue is not None:
            return self.result_queue
        return self._results[client_id].queue

    def start_thread(self, thread: ManagedThread):
        """Start ``thread`` if the engine accepts requests and it is not running."""
        if self.engine.can_enqueue_requests() and not thread.is_alive():
            thread.start()

    def abort_request(self, client_id: int) -> None:
        """Cancel the in-flight request submitted under ``client_id``.

        Only acts when the engine can enqueue requests. If no backend request
        id is known for ``client_id`` (e.g. the request already finished), a
        warning is logged and nothing is cancelled.
        """
        # NOTE: the request_id is the request_id generated by cpp runtime, not the client_id
        if not self.engine.can_enqueue_requests():
            return

        request_id = self._client_id_to_request_id.get(client_id, None)
        if request_id is None:
            logger.warning(
                f"Request of client_id {client_id} is finished, cannot abort it."
            )
            # Nothing to cancel; never hand ``None`` to the engine.
            return
        self.engine.cancel_request(request_id)

    def _engine_response_callback(self, response: tllm.Response):
        """Hook applied to each engine response; the default is the identity."""
        return response

    def await_response_task(self) -> bool:
        """Thread task: fetch and dispatch one batch of engine responses."""
        return self._await_response_helper()

    def _has_background_error(self) -> bool:
        """Return True when the background error queue is non-empty."""
        return not self._error_queue.empty()

    def _create_error_response(self, response: tllm.Response) -> ErrorResponse:
        """Turn the next queued background error into an ``ErrorResponse``.

        The error is consumed from the error queue; the client and request
        ids are taken from ``response``.
        """
        bck_error = self._error_queue.get_nowait()
        assert isinstance(bck_error, Exception)
        return ErrorResponse(response.client_id, str(bck_error),
                             response.request_id)

    def _iteration_result_task(self, it_result_queue: IterationResultQueue,
                               engine_get_result_api: Callable,
                               result_singleton: IterationResult,
                               result_serializer: Callable):
        """Pull per-iteration results from the engine and forward them.

        Args:
            it_result_queue: Queue holder used when results go to the proxy.
            engine_get_result_api: Zero-argument callable returning an
                iterable of results.
            result_singleton: Local result object whose queue is used when
                this is the LLM executor and the object is truthy.
            result_serializer: Converts one result into the payload to send.

        Returns:
            True on success (including an event-loop shutdown).

        Raises:
            Exception: Any other error is logged and re-raised.
        """
        # NOTE(review): the reason for this initial delay is not visible here.
        time.sleep(0.2)
        async_queues = []
        queue = result_singleton.queue if self._is_llm_executor and result_singleton else it_result_queue.queue
        try:
            for results in engine_get_result_api():
                res = result_serializer(results)
                if self._is_llm_executor and result_singleton:
                    # In this case, there's no ExecutorBindingProxy.
                    # Worker needs to take care of putting to result queue.
                    # Drop the oldest entries to make room for the new one.
                    while queue.full():
                        queue.get()
                    if isinstance(queue, _SyncQueue):
                        queue.put_nowait(res)
                        async_queues.append(queue)
                    else:
                        queue.put(res)
                else:
                    # Send to ExecutorBindingProxy via IPC
                    queue.put(res)

            # Wake up the async consumers once for the whole batch.
            if async_queues:
                _SyncQueue.notify_many(queue.loop, async_queues)
        except AsyncQueue.EventLoopShutdownError:
            # This happens in the last results loop while the generate workflow is stopped.
            logger.debug("worker.py: EventLoopShutdownError")
        except Exception as e:
            logger.error(f"worker.py: Error in _iteration_result_task: {e}")
            raise

        return True  # success

    def dispatch_stats_task(self) -> bool:
        """Thread task: forward the latest iteration stats as JSON strings."""
        return self._iteration_result_task(
            self.stats_queues, self.engine.get_latest_iteration_stats,
            self._iter_stats_result, lambda x: x.to_json_str())

    def dispatch_kv_cache_events_task(self) -> bool:
        """Thread task: forward the latest KV-cache events as JSON strings."""

        def serialize(events):
            return json.dumps(KVCacheEventSerializer.serialize(events))

        if isinstance(self.engine, tllm.Executor):
            # Check if the engine has a kv cache event manager.
            # If not, fall back to an API yielding a single ``None``
            # placeholder so the task still completes one pass.
            # NOTE(review): the placeholder is passed through the serializer;
            # confirm KVCacheEventSerializer.serialize accepts None.
            event_manager = self.engine.get_kv_cache_event_manager()
            if event_manager is None:
                events_api = lambda: [None]
            else:
                events_api = event_manager.get_latest_events
        else:
            events_api = self.engine.get_latest_kv_cache_events

        return self._iteration_result_task(self.kv_events_queues, events_api,
                                           self._iter_kv_events_result,
                                           serialize)

    def start(self):
        """Create the iteration-result queues and start the worker threads.

        Safe to call repeatedly: queues and threads that already exist or
        run are left alone. The stats thread is only started on MPI rank 0.
        """
        # create iteration result queues
        self._create_iteration_result_queue(self.stats_queues)
        self._create_iteration_result_queue(self.kv_events_queues)

        # start threads
        self.start_thread(self.await_response_thread)
        self.start_thread(self.dispatch_kv_cache_events_thread)
        if mpi_rank() == 0:
            self.start_thread(self.dispatch_stats_thread)

    def _load_lora_adapter(self, lora_request: LoRARequest):
        """Load the LoRA checkpoint of ``lora_request`` into the LoRA manager."""
        self._lora_manager.load_from_ckpt(
            [lora_request.path],
            model_config=self._runtime_model_config,
            runtime_mapping=None,
            uids=[str(lora_request.adapter_id)])

    def _load_prompt_adapter(self,
                             prompt_adapter_request: PromptAdapterRequest):
        """Load the checkpoint of ``prompt_adapter_request`` into its manager."""
        self._prompt_adapter_manager.load_from_ckpt(
            [prompt_adapter_request.local_path],
            model_config=self._runtime_model_config,
            uids=[str(prompt_adapter_request.adapter_id)])

    def _enqueue_request(self, request: GenerationRequest) -> int:
        """Translate ``request`` into a ``tllm.Request`` and enqueue it.

        Returns:
            The request id assigned by the backend engine.

        Raises:
            ValueError: For a context-only request on the pytorch backend
                with the overlap scheduler enabled and a KV-cache transceiver
                present.
            RequestError: If building or enqueueing the backend request fails.
        """
        assert request.id is not None

        lora_config = None
        if self._lora_manager is not None and request.lora_request is not None:
            self._load_lora_adapter(request.lora_request)
            uid = str(request.lora_request.adapter_id)
            lora_config = tllm.LoraConfig(
                task_id=request.lora_request.adapter_id,
                weights=self._lora_manager.cpp_lora_weights[uid],
                config=self._lora_manager.cpp_lora_config[uid])

        # Copy so that prepending virtual tokens never mutates the request.
        prompt_token_ids = copy.deepcopy(request.prompt_token_ids)
        prompt_tuning_config = None
        mrope_config = None
        if request.prompt_adapter_request is not None:
            assert request.prompt_tuning_config is None, \
                "cannot accept both prompt_adapter_request and prompt_tuning_config in one request"
            self._load_prompt_adapter(request.prompt_adapter_request)
            uid = str(request.prompt_adapter_request.adapter_id)
            prompt_tuning_config = tllm.PromptTuningConfig(
                self._prompt_adapter_manager.uid_to_weights[uid])
            # Prepend one id per embedding-table row, numbered from
            # vocab_size upwards.
            vocab_size = self._runtime_model_config.vocab_size
            pa_length = prompt_tuning_config.embedding_table.size(0)
            prompt_token_ids = list(range(
                vocab_size, vocab_size + pa_length)) + prompt_token_ids
        elif request.prompt_tuning_config is not None:
            prompt_tuning_config = tllm.PromptTuningConfig(
                request.prompt_tuning_config[0])

        if request.mrope_config is not None:
            mrope_config = tllm.MropeConfig(**request.mrope_config)

        context_phase_params = None
        request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
        if request.disaggregated_params is not None:
            request_type = request.disaggregated_params.get_request_type()
            if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY:
                context_phase_params = request.disaggregated_params.get_context_phase_params(
                )

        is_overlap_enabled = hasattr(
            self._executor_config, "backend"
        ) and self._executor_config.backend == "pytorch" and self._executor_config.pytorch_backend_config.enable_overlap_scheduler
        if is_overlap_enabled:
            is_disaggregated = self.engine.kv_cache_transceiver is not None
            if is_disaggregated and (
                    request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
                raise ValueError(
                    "Context only requests are not supported in pytorch backend when overlap is enabled."
                )

        sampling_params = request.sampling_params
        try:
            executor_request = tllm.Request(
                client_id=request.id,
                input_token_ids=prompt_token_ids,
                max_tokens=sampling_params.max_tokens,
                max_new_tokens=sampling_params.max_new_tokens,
                streaming=request.streaming,
                sampling_config=sampling_params._get_sampling_config(),
                end_id=-1 if sampling_params.ignore_eos else
                sampling_params.end_id,
                pad_id=sampling_params.pad_id,
                output_config=sampling_params._get_output_config(),
                # Beam search enforces return_all_generated_tokens=True regardless of the passed value
                return_all_generated_tokens=False,
                # convert python config into pybind config
                lookahead_config=PybindMirror.maybe_to_pybind(
                    sampling_params.lookahead_config),
                guided_decoding_params=sampling_params.
                _get_guided_decoding_params(),
                bad_words=sampling_params._get_bad_words(),
                stop_words=sampling_params._get_stop_words(),
                embedding_bias=sampling_params.embedding_bias,
                lora_config=lora_config,
                prompt_tuning_config=prompt_tuning_config,
                mrope_config=mrope_config,
                logits_post_processor_name=(
                    tllm.Request.BATCHED_POST_PROCESSOR_NAME
                    if sampling_params.apply_batched_logits_processor
                    else None),
                logits_post_processor=sampling_params.logits_processor,
                kv_cache_retention_config=request.kv_cache_retention_config,
                context_phase_params=context_phase_params,
                type=request_type)

            if request.query_token_ids is not None:
                # pytorch star attention workflow
                # a workaround to avoid public interface update
                req_id = self.engine.enqueue_request(executor_request,
                                                     request.query_token_ids)
            else:
                req_id = self.engine.enqueue_request(executor_request)
            return req_id
        except Exception as e:
            # Keep the original exception attached as the cause.
            raise RequestError(str(e)) from e

    def submit(self, request: GenerationRequest) -> GenerationResult:
        """ Low-level API to the executor. Return a "future" GenerationResult which can be waited. """
        self.start()

        if self.rank != 0:
            raise RuntimeError(
                "Only rank 0 can submit requests.\n"
                "To fix this, ensure that the llm.generate(...) method is "
                "guarded with the `if __name__ == '__main__':` block.")

        client_id = request.id if request.id is not None else self._get_next_client_id(
        )
        if request.id is None:
            request.set_id(client_id)

        result = GenerationResult(
            request,
            background_error_handler=self._handle_background_error,
            executor=self,
            disaggregated_params=request.disaggregated_params)

        # Register before enqueueing so the response thread can find it.
        self._results[client_id] = result

        try:
            request_id = self._enqueue_request(request)
        except Exception:
            # The request never reached the engine, so no response will ever
            # remove this entry; drop it here to avoid leaking the result.
            self._results.pop(client_id, None)
            raise
        # request_id returned from backend is necessary for the abort_request method.
        self._client_id_to_request_id[client_id] = request_id

        self._handle_background_error()

        return result

    def _pop_result(self, client_id: int):
        """Forget all bookkeeping for ``client_id`` (missing ids are ignored)."""
        self._results.pop(client_id, None)
        self._client_id_to_request_id.pop(client_id, None)

    def shutdown(self):
        """Stop the worker threads and shut the engine down (idempotent)."""
        if self.doing_shutdown:
            return
        self.doing_shutdown = True

        print_colored_debug(f'Worker {mpi_rank()} shutdown...\n', "yellow")

        if self.engine is not None:
            if self.engine.can_enqueue_requests():
                for thread in (self.await_response_thread,
                               self.dispatch_stats_thread,
                               self.dispatch_kv_cache_events_thread):
                    if thread.is_alive():
                        thread.stop()
                        thread.join()

            self.engine.shutdown()
            self.engine = None

        # Check if there are any errors from the threads before shutdown.
        self._handle_background_error()

        print_colored_debug(f"Worker {mpi_rank()} shutdown done.\n", "yellow")

    def block_subordinates(self):
        """Keep non-leader ranks out of the request loop.

        On ranks other than 0: with the C++ executor the worker is shut down
        and ``WorkerExit`` is raised; with a ``PyExecutor`` the call blocks in
        ``wait_shutdown()``. Rank 0 returns immediately.
        """
        if self.rank == 0:
            return

        if isinstance(self.engine, tllm.Executor):
            self.shutdown()
            raise self.WorkerExit(
                "block_subordinates() should be used in a `with ExecutorBindingsWorker() as ...:` block"
            )

        from tensorrt_llm._torch.pyexecutor.py_executor import PyExecutor
        if isinstance(self.engine, PyExecutor):
            self.engine.wait_shutdown()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        """Shut down; suppress only ``WorkerExit`` (other errors propagate)."""
        self.shutdown()
        return exc_type is None or exc_type == ExecutorBindingsWorker.WorkerExit

    def __del__(self):
        self.shutdown()
@print_traceback_on_error
def worker_main(
    engine: Path | Engine,
    worker_queues: WorkerCommIpcAddrs,
    log_level: str,
    executor_config: Optional[tllm.ExecutorConfig] = None,
    batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
    worker_cls: type = ExecutorBindingsWorker,
    tracer_init_kwargs: Optional[dict] = None,
    _torch_model_class_mapping: Optional[dict] = None,
    postproc_worker_config: Optional[PostprocWorkerConfig] = None,
    ready_signal: Optional[str] = None,
    is_llm_executor: Optional[
        bool] = True,  # whether it's the main executor instance
) -> None:
    """Entry point of a worker (MPI) process.

    Connects the IPC queues on the leader rank, optionally spawns the
    post-processing workers, builds the worker, and then serves requests
    from the request queue until ``None`` is received.

    Args:
        engine: Engine directory or in-memory ``Engine`` passed to the worker.
        worker_queues: Addresses of the IPC queues shared with the proxy.
        log_level: Log level applied on the leader rank.
        executor_config: Backend configuration passed to the worker.
        batched_logits_processor: Passed to the worker.
        worker_cls: Worker class to instantiate.
        tracer_init_kwargs: If given, a ``VizTracer`` is started on the leader.
        _torch_model_class_mapping: Extra entries merged into
            ``MODEL_CLASS_MAPPING``.
        postproc_worker_config: Post-processing settings; defaults to
            ``PostprocWorkerConfig()``.
        ready_signal: Value sent on the request-error queue once the leader
            is ready to accept requests.
        is_llm_executor: Passed to the worker.
    """
    mpi_comm().barrier()
    print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
                        "green")

    pid = os.getpid()
    cpus = os.sched_getaffinity(pid)
    if cpus:
        logger.warning(
            f"Found worker process {pid} was bound to {cpus}, this may harm "
            "performance.", )
        logger.warning("Will clear the cpu affinity")
        clear_sched_affinity(pid)

    result_queue: Optional[IpcQueue] = None
    result_queues: Optional[List[IpcQueue]] = None

    postproc_worker_config = postproc_worker_config or PostprocWorkerConfig()

    is_leader: bool = mpi_rank() == 0
    if tracer_init_kwargs is not None and is_leader:
        tracer = VizTracer(**tracer_init_kwargs)
        tracer.register_exit()
        tracer.start()
        set_global_tracer(tracer)

    if _torch_model_class_mapping is not None:
        from tensorrt_llm._torch.models.modeling_auto import MODEL_CLASS_MAPPING
        MODEL_CLASS_MAPPING.update(**_torch_model_class_mapping)

    set_mpi_session_cpp(mpi_comm())

    if is_leader:
        # Only set the log level for the leader process, the other processes will
        # inherit the log level from "TLLM_LOG_LEVEL" environment variable
        logger.set_level(log_level)
        request_queue = IpcQueue(worker_queues.request_queue_addr,
                                 is_server=False,
                                 name="worker_request_queue")
        request_error_queue = IpcQueue(worker_queues.request_error_queue_addr,
                                       is_server=False,
                                       name="worker_request_error_queue")
        mp_stats_queue = FusedIpcQueue(worker_queues.stats_queue_addr,
                                       is_server=False,
                                       fuse_message=True,
                                       name="worker_stats_queue")
        kv_cache_events_queue = FusedIpcQueue(
            worker_queues.kv_cache_events_queue_addr,
            is_server=False,
            fuse_message=False,
            name="worker_kv_cache_events_queue")

        if postproc_worker_config.enabled:
            # IPC queues for sending inputs to the postprocess parallel
            # processes, each one is a PAIR zmq socket
            result_queues = [
                FusedIpcQueue(is_server=True,
                              fuse_message=PERIODICAL_RESP_IN_AWAIT,
                              name=f"postprocess_{i}_feedin_queue")
                for i in range(postproc_worker_config.num_postprocess_workers)
            ]
        else:
            # IPC queue for sending results back to the proxy, and let the
            # Proxy process to handle the postprocess
            result_queue = FusedIpcQueue(worker_queues.result_queue_addr,
                                         is_server=False,
                                         fuse_message=PERIODICAL_RESP_IN_AWAIT,
                                         name="worker_result_queue")

    def notify_proxy_threads_to_quit():
        """Send the ``None`` sentinel on every outgoing queue (leader only)."""
        # Signal the dispatcher thread in the proxy to quit
        if result_queue is not None:
            result_queue.put(None)
        else:
            assert result_queues is not None
            for q in result_queues:
                q.put(None)
        # Signal the stats thread in the proxy to quit
        mp_stats_queue.put(None)
        kv_cache_events_queue.put(None)

    postprocess_worker_futures = []
    if is_leader and postproc_worker_config.enabled:
        print_colored_debug("initiate postprocess workers...", "yellow")

        proxy_result_queue: tuple[
            str, Optional[bytes]] = worker_queues.result_queue_addr

        assert result_queues is not None
        assert postproc_worker_config.postprocess_tokenizer_dir is not None
        postproc_worker_pool = ProcessPoolExecutor(
            max_workers=postproc_worker_config.num_postprocess_workers)
        assert isinstance(proxy_result_queue, tuple)
        for i in range(postproc_worker_config.num_postprocess_workers):
            fut = postproc_worker_pool.submit(
                postproc_worker_main,
                result_queues[i].address,
                proxy_result_queue,
                postproc_worker_config.postprocess_tokenizer_dir,
                PostprocWorker.default_record_creator,
            )
            postprocess_worker_futures.append(fut)

    # Error handling in the Worker/MPI process
    #   1. During Executor initialization, the errors will be captured and
    #      send back via request_error_queue.
    #   2. During execution, the errors will be captured by ManagedThreads
    #      a) For per-request error, the error will be send back via
    #         result_queue, and eventually raised in handle_response() in
    #         the main thread.
    #      b) For system error, the error will be raised in the MPI process
    #         and handled by future.done_callback, that will propagate the
    #         error to the error_queue in the main thread.

    mpi_comm().barrier()
    print_colored_debug(f"Worker {mpi_rank()} ready to setup backend...\n",
                        "green")

    try:
        worker: ExecutorBindingsWorker = worker_cls(
            engine,
            executor_config,
            batched_logits_processor,
            postproc_worker_config=postproc_worker_config,
            is_llm_executor=is_llm_executor)
    except Exception as e:
        logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
        logger.error(traceback.format_exc())
        if is_leader:
            request_error_queue.put(e)
        return

    with worker:
        try:
            # Non-leader ranks either raise WorkerExit or block in here.
            worker.block_subordinates()

            if is_leader:
                if postproc_worker_config.enabled:
                    worker.set_postproc_queues(result_queues)
                else:
                    worker.set_result_queue(result_queue)

                # initialize the iteration result queues
                worker._set_iteration_result_queue(worker.stats_queues,
                                                   mp_stats_queue)
                worker._set_iteration_result_queue(worker.kv_events_queues,
                                                   kv_cache_events_queue)
                request_error_queue.put(ready_signal)
                # A ``None`` request is the shutdown sentinel.
                while (req := request_queue.get()) is not None:
                    if isinstance(req, CancellingRequest):
                        worker.abort_request(req.id)
                    elif isinstance(req, GenerationRequest):
                        try:
                            worker.submit(req)
                            request_error_queue.put(None)  # None means success
                        except RequestError as e:
                            request_error_queue.put(e)
                    else:
                        raise ValueError(f"Unknown request type: {type(req)}")

                notify_proxy_threads_to_quit()

        except ExecutorBindingsWorker.WorkerExit:
            # This will capture by the with-statement and exit normally.
            raise

        except Exception as e:  # other critical errors
            if is_leader:
                notify_proxy_threads_to_quit()
            err = Exception(f"Failed during generation: {e}")
            logger.error(traceback.format_exc())
            if is_leader:
                request_error_queue.put(err)
class AwaitResponseHelper:
''' Multiple-implementations for await_response for performance. '''
class HandlerKind(enum.Enum):
unknown = 0
single_process_worker = 1
ipc_periodically = 2
ipc_batched = 3
def __init__(self, worker: "ExecutorBindingsWorker"):
# TODO: make worker weakref
self.worker = worker
self.handler_kind: AwaitResponseHelper.HandlerKind = AwaitResponseHelper.HandlerKind.unknown
self.enable_postprocprocess_parallel = self.worker.enable_postprocess_parallel
def responses_handler(self, responses: List[tllm.Response]):
HandlerKind = AwaitResponseHelper.HandlerKind
if self.handler_kind is HandlerKind.unknown:
if not (self.worker.result_queue is not None
or self.worker.postproc_queues is not None):
print_colored_debug(
f"creating await_response helper for Worker\n",
color="yellow")
# When ExecutorBindingWorker is used in the main process
# aka the single process mode
self.handler_kind = HandlerKind.single_process_worker
elif self.worker.result_queue is not None or self.worker.postproc_queues is not None:
# The ExecutorBindingProxy is used
print_colored_debug(f"creating await_response helper for IPC\n",
color="yellow")
if PERIODICAL_RESP_IN_AWAIT:
self.handler_kind = HandlerKind.ipc_periodically
else:
self.handler_kind = HandlerKind.ipc_batched
else:
raise NotImplementedError
match self.handler_kind:
case HandlerKind.single_process_worker:
return self.handle_for_worker(responses)
case HandlerKind.ipc_batched:
return self.handle_for_ipc_batched(responses)
case HandlerKind.ipc_periodically:
return self.handle_for_ipc_periodically(responses)
case _:
raise NotImplementedError
def __call__(self) -> bool:
''' This method should be called by a ManagedThread. '''
responses = self.worker.engine.await_responses(
timeout=datetime.timedelta(milliseconds=100))
# filter since The _engine_response_callback may return None
responses = list(
filter(
lambda _: _,
[self.worker._engine_response_callback(r) for r in responses]))
with nvtx_range_debug(f"await_response-{len(responses)}",
color="red",
category="Worker"):
self.responses_handler(responses)
return True
def handle_for_worker(self, responses: List[tllm.Response]) -> None:
''' Return the responses to asyncio.event_loop. '''
event_loop = None
async_queues = []
for response in responses:
assert response is not None
queue = self.worker.return_queue(response.client_id)
# For AsyncQueue.sync_q, we will batch the events to avoid too many
# event notifications, thus put without wait here.
if isinstance(queue, _SyncQueue):
global_tracer().log_instant("worker-rsp.put")
queue.put_nowait(response)
async_queues.append(queue)
# all the loops are identical
event_loop = event_loop or queue.loop
else:
queue.put(response)
if response.has_error() or response.result.is_final:
self.worker._pop_result(response.client_id)
# Notify the events in bulk for performance.
if async_queues:
_SyncQueue.notify_many(event_loop, async_queues)
def handle_for_ipc_periodically(self,
responses: List[tllm.Response]) -> None:
''' Return the responses to Proxy via IPC. This will put Rsp to a Queue
in a FusedIpcQueue, and a background thread will batch them and invoke
IPC periodically. '''
with nvtx_range_debug(f"handle_for_ipc_periodically-{len(responses)}",
color="red",
category="Worker"):
for response in responses:
if self.worker._has_background_error():
response = self.worker._create_error_response(response)
elif response.has_error():
response = ErrorResponse(response.client_id,
response.error_msg,
response.request_id)
# TODO: To verify the performance of using ZMQ instead of SharedMemory
# to send the logits tensor back to the Proxy process.
_send_rsp(self.worker, response)
def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None:
''' Perform the IPC in batch explicitly. '''
postproc_batches = [
[]
for _ in range(self.worker.postproc_config.num_postprocess_workers)
] if self.enable_postprocprocess_parallel else None
rsp_batch = [] if not self.enable_postprocprocess_parallel else None
for response in responses:
if self.worker._has_background_error():
response = self.worker._create_error_response(response)
elif response.has_error():
# Convert to ErrorResponse, because tllm.Response cannot be
# serialized when it has error.
response = ErrorResponse(response.client_id, response.error_msg,
response.request_id)
_send_rsp(self.worker,
response,
postproc_batches=postproc_batches,
rsp_batch=rsp_batch)
if postproc_batches:
for wid, batch in enumerate(postproc_batches):
if batch:
self.worker.postproc_queues[wid].put(batch)
if rsp_batch:
self.worker.result_queue.put(rsp_batch)
def _get_params_for_first_rsp(
        worker,
        client_id) -> Tuple[Optional[SamplingParams], Optional[PostprocParams]]:
    """Return the request parameters once per client, then ``(None, None)``.

    The first call for a given ``client_id`` marks the stored result as
    transmitted and returns its sampling and post-processing parameters;
    every later call returns ``(None, None)``. A result must already be
    registered on ``worker`` for ``client_id``.
    """
    result = worker._results.get(client_id, None)
    assert result is not None

    if result._params_transmitted:
        return None, None

    result._params_transmitted = True
    return result.sampling_params, result.postproc_params
def _send_rsp(
        worker,
        response: Union[tllm.Response, ErrorResponse],
        postproc_batches: Optional[List[List["PostprocWorker.Input"]]] = None,
        rsp_batch: Optional[List[tllm.Response]] = None):
    """Route one response to the proxy or to a post-processing worker.

    With a centralized result queue on ``worker`` the response is appended to
    ``rsp_batch`` when one is given, otherwise put on the queue directly.
    Without it, the response is wrapped in a ``PostprocWorker.Input`` and
    either appended to ``postproc_batches`` or put on the post-processing
    queue selected by ``client_id``. Bookkeeping for finished or failed
    requests is then dropped.

    Raises:
        ValueError: If ``response`` is neither a ``tllm.Response`` nor an
            ``ErrorResponse``.
    """
    # if postproc_batches is set, append to batch instead of putting to IpcQueue
    client_id = response.client_id

    if worker.result_queue is None:
        sampling_params, postproc_params = _get_params_for_first_rsp(
            worker, client_id)
        result = worker._results.get(client_id, None)
        inp = PostprocWorker.Input(
            response,
            # sampling_params is necessary for creating fake GenerationResult
            # instances in the postproc processes. They are for incremental
            # detokenize. They should be transmitted only once for each
            # Request.
            sampling_params=sampling_params,
            postproc_params=postproc_params,
            streaming=result._streaming)

        # Group the responses into buckets for the postprocessing steps.
        # Bucketing is used instead of random dispatching because the
        # incremental detokenization during postprocessing relies on the
        # prior CompletionOutput of a given request.
        pid = client_id % worker.postproc_config.num_postprocess_workers
        if postproc_batches:
            postproc_batches[pid].append(inp)
        else:
            worker.postproc_queues[pid].put(inp)
    elif rsp_batch is None:
        worker.result_queue.put(response)
    else:
        rsp_batch.append(response)

    # Eliminate the finished GenerationRequest instances timely, which may
    # take considerable memory.
    if isinstance(response, tllm.Response):
        finished = response.has_error() or response.result.is_final
    elif isinstance(response, ErrorResponse):
        finished = True
    else:
        raise ValueError(f"Unknown response type: {response}")

    if finished:
        worker._pop_result(response.client_id)