# tensorrt_llm/executor/worker.py

import copy
import datetime
import enum
import json
import os
import time
import traceback
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from queue import Queue
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import tensorrt_llm.executor.serialization as serialization
from tensorrt_llm.logger import logger
from .._utils import (KVCacheEventSerializer, global_mpi_rank, mpi_comm,
mpi_rank, nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import PybindMirror
from ..llmapi.mpi_session import set_mpi_session_cpp
from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
clear_sched_affinity, print_colored_debug,
print_traceback_on_error)
from ..lora_manager import LoraConfig, LoraManager
from ..prompt_adapter_manager import PromptAdapterManager
from ..runtime import ModelConfig
from ..runtime.model_runner import _engine_config_to_model_config
from ..sampling_params import BatchedLogitsProcessor, SamplingParams
from .executor import GenerationExecutor, IterationResultQueue
from .ipc import FusedIpcQueue, IpcQueue
from .postproc_worker import (PostprocParams, PostprocWorker,
PostprocWorkerConfig, postproc_worker_main)
from .request import (CancellingRequest, GenerationRequest, LoRARequest,
PromptAdapterRequest)
from .result import (GenerationResult, IterationResult, LogProbsResult,
ResponseWrapper, compute_logprobs)
from .utils import (PERIODICAL_RESP_IN_AWAIT, ErrorResponse, IntraProcessQueue,
RequestError, WorkerCommIpcAddrs, has_event_loop,
is_llm_response)
__all__ = [
"GenerationExecutorWorker",
]
class GenerationExecutorWorker(GenerationExecutor):
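    """GenerationExecutor implementation that drives the underlying engine
    (the C++ ``tllm.Executor``, the PyTorch executor, or the autodeploy
    executor) inside a worker (MPI) process.

    Rough usage sketch (illustrative only; in practice the worker is created
    by ``worker_main`` and driven by the proxy over IPC):

        with GenerationExecutorWorker(engine_path) as worker:
            result = worker.submit(generation_request)
    """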
class WorkerExit(GeneratorExit):
pass
def __init__(
self,
engine: Union[Path, Engine],
executor_config: Optional[tllm.ExecutorConfig] = None,
batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
lora_config: Optional[LoraConfig] = None,
) -> None:
postproc_config = postproc_worker_config or PostprocWorkerConfig()
super().__init__(
num_postprocess_workers=postproc_config.num_postprocess_workers,
postprocess_tokenizer_dir=postproc_config.postprocess_tokenizer_dir,
is_llm_executor=is_llm_executor,
)
self.engine = None
self.result_queue: Optional[IpcQueue] = None
self.postproc_queues: Optional[List[IpcQueue]] = None
self.rank = mpi_rank()
self.global_rank = global_mpi_rank()
# mapping: client_id -> GenerationResult
self._results: Dict[int, GenerationResult] = {}
# mapping: client_id from Proxy -> request_id returned from runtime backend
self._client_id_to_request_id: Dict[int, int] = {}
self._await_response_helper = AwaitResponseHelper(
self) # TODO: make it weakref
self._executor_config = executor_config
self._is_pytorch_backend = getattr(self._executor_config, "backend",
None) == "pytorch"
if isinstance(engine, list):
engine = engine[self.rank]
if executor_config is None:
executor_config = tllm.ExecutorConfig(1)
executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
processor_batched=batched_logits_processor, replicate=False)
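        # _create_engine selects the backend: a C++ tllm.Executor for TRT
        # engines, or the PyTorch / autodeploy executors when
        # executor_config.backend is set accordingly.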
def _create_engine():
device_id = self.global_rank % torch.cuda.device_count()
torch.cuda.set_device(device_id)
# Make sure C++ executor would use same devices/ranks as py_executor
global_rank = global_mpi_rank()
comm_ranks = mpi_comm().allgather(global_rank)
device_ids = mpi_comm().allgather(device_id)
executor_config.parallel_config = tllm.ParallelConfig(
participant_ids=comm_ranks, device_ids=device_ids)
if isinstance(engine, Engine):
return tllm.Executor(engine.engine,
json.dumps(engine.config.to_dict(),
cls=ConfigEncoder),
tllm.ModelType.DECODER_ONLY,
executor_config=executor_config,
managed_weights=engine.managed_weights)
if not hasattr(executor_config, "backend"):
return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
executor_config)
args = {
"executor_config": executor_config,
"checkpoint_dir": executor_config.hf_model_dir,
"engine_dir": executor_config.trt_engine_dir,
}
if executor_config.backend == "pytorch":
from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
create_py_executor
create_executor = create_py_executor
args["lora_config"] = lora_config
elif executor_config.backend == "autodeploy":
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
create_autodeploy_executor
create_executor = create_autodeploy_executor
else:
raise ValueError(
f"Unsupported backend config: {executor_config.backend}")
return create_executor(**args)
self.engine = _create_engine()
self._lora_manager: Optional[LoraManager] = None
self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
self._runtime_model_config: Optional[ModelConfig] = None
if self.rank == 0 and isinstance(self.engine, tllm.Executor):
if isinstance(engine, Engine):
engine_config = engine.config
else:
engine_config = EngineConfig.from_json_file(
f"{engine}/config.json")
self._runtime_model_config = _engine_config_to_model_config(
engine_config)
if engine_config.build_config.plugin_config.lora_plugin:
self._lora_manager = LoraManager()
if engine_config.build_config.max_prompt_embedding_table_size > 0:
self._prompt_adapter_manager = PromptAdapterManager()
if getattr(executor_config, "backend",
"") == "pytorch" and lora_config is not None:
self._lora_manager = LoraManager()
lora_model_config = self.engine.model_engine.lora_model_config
assert lora_model_config is not None
self._lora_model_config = lora_model_config
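        # Background threads: await_response_thread pulls responses from the
        # engine, while the dispatch threads forward iteration stats and KV
        # cache events to their result queues.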
self.await_response_thread = ManagedThread(
self.await_response_task,
error_queue=self._error_queue,
name="await_response_thread")
self.dispatch_stats_thread = ManagedThread(
self.dispatch_stats_task,
error_queue=self._error_queue,
name="dispatch_stats_thread")
self.dispatch_kv_cache_events_thread = ManagedThread(
self.dispatch_kv_cache_events_task,
error_queue=self._error_queue,
name="dispatch_kv_cache_events_thread")
def set_result_queue(self, queue):
"""In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process."""
assert self.postproc_queues is None
self.result_queue = queue
def set_postproc_queues(self, queues: List["IpcQueue"]):
""" Set the IPC queues for feeding post-processing processes. """
assert self.result_queue is None
self.postproc_queues = queues
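    # _create_iteration_result_queue builds a local (optionally asyncio-aware)
    # queue when this worker serves results directly, while
    # _set_iteration_result_queue attaches an externally provided IPC queue
    # instead (proxy-driven mode).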
def _create_iteration_result_queue(self,
it_result_queue: IterationResultQueue):
if not it_result_queue.is_initialized:
# not yet initialized
it_result_queue.is_initialized = True
if has_event_loop():
_queue = AsyncQueue()
it_result_queue.queue = _queue.sync_q
it_result_queue.aqueue = _queue
else:
_queue = Queue()
it_result_queue.queue = _queue
it_result_queue.aqueue = None
def _set_iteration_result_queue(self, it_result_queue: IterationResultQueue,
queue: Union[Queue, FusedIpcQueue,
IntraProcessQueue]):
assert not it_result_queue.is_initialized, "Iteration result queue should not already be initialized."
it_result_queue.is_initialized = True
it_result_queue.queue = queue
it_result_queue.aqueue = None
def return_queue(self, client_id: int):
""" If a centralized result queue is registered (used for communication with the proxy)
send the message there.
Otherwise, push the result directly in the GenerationResult queue.
"""
if self.result_queue is not None:
return self.result_queue
return self._results[client_id].queue
def start_thread(self, thread: ManagedThread):
if self.engine.can_enqueue_requests() and not thread.is_alive():
thread.start()
def abort_request(self, client_id: int) -> None:
# NOTE: the request_id is the request_id generated by cpp runtime, not the client_id
if self.engine.can_enqueue_requests():
request_id = self._client_id_to_request_id.get(client_id, None)
if request_id is None:
                logger.warning(
                    f"Request with client_id {client_id} has already finished; cannot abort it."
                )
return
self.engine.cancel_request(request_id)
def _engine_response_callback(self, response: tllm.Response):
return response
def await_response_task(self) -> bool:
return self._await_response_helper()
def _has_background_error(self) -> bool:
return not self._error_queue.empty()
def _create_error_response(self, response: tllm.Response) -> ErrorResponse:
bck_error = self._error_queue.get_nowait()
assert isinstance(bck_error, Exception)
return ErrorResponse(response.client_id, str(bck_error),
response.request_id)
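    # Shared driver for the stats and KV-cache-event dispatch threads: pull
    # results from the engine API, serialize them, and push them either to the
    # local iteration-result queue or to the proxy via IPC.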
def _iteration_result_task(self, it_result_queue: IterationResultQueue,
engine_get_result_api: Callable,
result_singleton: IterationResult,
result_serializer: Callable):
time.sleep(0.2)
async_queues = []
queue = result_singleton.queue if self._is_llm_executor and result_singleton else it_result_queue.queue
try:
for results in engine_get_result_api():
res = result_serializer(results)
if self._is_llm_executor and result_singleton:
                    # In this case, there is no GenerationExecutorProxy;
                    # the worker is responsible for putting results into the queue.
while queue.full():
queue.get()
if isinstance(queue, _SyncQueue):
queue.put_nowait(res)
async_queues.append(queue)
else:
queue.put(res)
else:
                    # Send to the GenerationExecutorProxy via IPC
queue.put(res)
if async_queues:
_SyncQueue.notify_many(queue.loop, async_queues)
except AsyncQueue.EventLoopShutdownError:
            # This happens in the final results loop after the generate workflow has been stopped.
logger.debug("worker.py: EventLoopShutdownError")
except Exception as e:
logger.error(f"worker.py: Error in _iteration_result_task: {e}")
raise e
return True # success
def dispatch_stats_task(self) -> bool:
# Define a Callable to join iteration and request stats
def stats_serializer(
stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str:
iteration_stats, req_stats = stats
stats_dict = json.loads(iteration_stats.to_json_str())
if req_stats is not None and len(req_stats) > 0:
stats_dict["requestStats"] = []
for req_stat in req_stats:
stats_dict["requestStats"].append(
json.loads(req_stat.to_json_str()))
# Convert back to JSON string
return json.dumps(stats_dict)
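        # Resulting JSON shape (field names abbreviated, illustrative only):
        #   {"...iteration stats fields...", "requestStats": [{...}, ...]}
        # where "requestStats" is included only when request stats are present.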
def get_stats():
if isinstance(self.engine, tllm.Executor):
iter_stats = self.engine.get_latest_iteration_stats()
#TODO: Support req stats with TRT engine
# This would require ensuring iter and req stats have same size
return [(iter_stat, None) for iter_stat in iter_stats]
else:
return self.engine.get_latest_iteration_stats()
return self._iteration_result_task(self.stats_queues, get_stats,
self._iter_stats_result,
stats_serializer)
def dispatch_kv_cache_events_task(self) -> bool:
if isinstance(self.engine, tllm.Executor):
# Check if the engine has a kv cache event manager
# If not, return an empty list for the events which will cause the thread to exit early.
event_manager = self.engine.get_kv_cache_event_manager()
if event_manager is None:
events_api = lambda: [None]
else:
events_api = event_manager.get_latest_events
return self._iteration_result_task(
self.kv_events_queues, events_api, self._iter_kv_events_result,
lambda x: json.dumps(KVCacheEventSerializer.serialize(x)))
else:
return self._iteration_result_task(
self.kv_events_queues, self.engine.get_latest_kv_cache_events,
self._iter_kv_events_result,
lambda x: json.dumps(KVCacheEventSerializer.serialize(x)))
def start(self):
# create iteration result queues
self._create_iteration_result_queue(self.stats_queues)
self._create_iteration_result_queue(self.kv_events_queues)
# start threads
self.start_thread(self.await_response_thread)
self.start_thread(self.dispatch_kv_cache_events_thread)
if mpi_rank() == 0:
self.start_thread(self.dispatch_stats_thread)
def _load_lora_adapter(self, lora_request: LoRARequest):
self._lora_manager.load_from_ckpt(
[lora_request.path],
model_config=self._runtime_model_config if
self._runtime_model_config is not None else self._lora_model_config,
runtime_mapping=None,
uids=[str(lora_request.adapter_id)])
def _load_prompt_adapter(self,
prompt_adapter_request: PromptAdapterRequest):
self._prompt_adapter_manager.load_from_ckpt(
[prompt_adapter_request.local_path],
model_config=self._runtime_model_config,
uids=[str(prompt_adapter_request.adapter_id)])
def _enqueue_request(self, request: GenerationRequest) -> int:
assert request.id is not None
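        # Overall flow: load LoRA / prompt-adapter weights if requested, deduce
        # max_tokens when unset, build the tllm.Request, enqueue it into the
        # engine, and return the runtime request id.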
if self._lora_manager is not None and request.lora_request is not None:
self._load_lora_adapter(request.lora_request)
uid = str(request.lora_request.adapter_id)
lora_config = tllm.LoraConfig(
task_id=request.lora_request.adapter_id,
weights=self._lora_manager.cpp_lora_weights[uid],
config=self._lora_manager.cpp_lora_config[uid])
else:
lora_config = None
prompt_token_ids = copy.deepcopy(request.prompt_token_ids)
prompt_tuning_config = None
multimodal_embedding = None
mrope_config = None
if request.multimodal_embedding is not None:
multimodal_embedding = request.multimodal_embedding
if request.prompt_adapter_request is not None:
self._load_prompt_adapter(request.prompt_adapter_request)
uid = str(request.prompt_adapter_request.adapter_id)
prompt_tuning_config = tllm.PromptTuningConfig(
self._prompt_adapter_manager.uid_to_weights[uid])
vocab_size = self._runtime_model_config.vocab_size
pa_length = prompt_tuning_config.embedding_table.size(0)
prompt_token_ids = list(range(
vocab_size, vocab_size + pa_length)) + prompt_token_ids
if request.mrope_config is not None:
mrope_config = tllm.MropeConfig(**request.mrope_config)
context_phase_params = None
request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
if request.disaggregated_params is not None:
request_type = request.disaggregated_params.get_request_type()
if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY:
context_phase_params = request.disaggregated_params.get_context_phase_params(
)
is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler
if is_overlap_enabled:
is_disaggregated = self.engine.kv_cache_transceiver is not None
if is_disaggregated and (
request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
                raise ValueError(
                    "Context-only requests are not supported in the PyTorch backend when the overlap scheduler is enabled."
                )
assert request.id is not None
def _deduce_max_tokens(request: GenerationRequest,
executor_config: tllm.ExecutorConfig) -> int:
if request.sampling_params.max_tokens:
return request.sampling_params.max_tokens
# deduce max_tokens when it's not set by user
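            # e.g. with max_seq_len=4096, cp_size=1, a 1024-token prompt and no
            # query tokens, the deduced value is 4096 - 1024 - 0 = 3072
            # (illustrative numbers only).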
query_token_len = len(
request.query_token_ids) if request.query_token_ids else 0
cp_size = 1 if (not hasattr(executor_config, "mapping")
or executor_config.mapping.cp_size
is None) else executor_config.mapping.cp_size
if not hasattr(executor_config, "max_seq_len"):
raise RuntimeError(
"max_tokens for sampling is not set and cannot be deduced")
            split_prompt_len = int(len(prompt_token_ids) / cp_size)
            default_max_tokens = executor_config.max_seq_len - split_prompt_len - query_token_len
            if default_max_tokens < 0:
                raise ValueError(
                    f"Deduced max_tokens {default_max_tokens} is less than 0, because "
                    f"prompt length {split_prompt_len} plus query length {query_token_len} "
                    f"is larger than max_seq_len {executor_config.max_seq_len}")
return default_max_tokens
try:
executor_request = tllm.Request(
client_id=request.id,
input_token_ids=prompt_token_ids,
max_tokens=_deduce_max_tokens(request, self._executor_config),
streaming=request.streaming,
sampling_config=request.sampling_params._get_sampling_config(),
end_id=-1 if request.sampling_params.ignore_eos else
request.sampling_params.end_id,
pad_id=request.sampling_params.pad_id,
output_config=request.sampling_params._get_output_config(
is_pytorch_backend=self._is_pytorch_backend),
# Beam search enforces return_all_generated_tokens=True regardless of the passed value
return_all_generated_tokens=False,
# convert python config into pybind config
lookahead_config=PybindMirror.maybe_to_pybind(
request.sampling_params.lookahead_config),
guided_decoding_params=request.sampling_params.
_get_guided_decoding_params(),
bad_words=request.sampling_params._get_bad_words(),
stop_words=request.sampling_params._get_stop_words(),
embedding_bias=request.sampling_params.embedding_bias,
lora_config=lora_config,
prompt_tuning_config=prompt_tuning_config,
multimodal_embedding=multimodal_embedding,
mrope_config=mrope_config,
logits_post_processor_name=(
tllm.Request.BATCHED_POST_PROCESSOR_NAME
if request.sampling_params.apply_batched_logits_processor
else None),
logits_post_processor=None if self._is_pytorch_backend else
request.sampling_params.logits_processor,
kv_cache_retention_config=request.kv_cache_retention_config,
context_phase_params=context_phase_params,
type=request_type)
if self._is_pytorch_backend and request.sampling_params.logits_processor:
# For PyTorch backend, we attach logits processors as a dynamic Python attribute
# instead of using the C++ binding, since the latter will cause PyCapsule pickling issues.
lp = request.sampling_params.logits_processor
executor_request.py_logits_post_processors = lp if isinstance(
lp, list) else [lp]
if request.query_token_ids is not None:
                # PyTorch star-attention workflow: a workaround to avoid
                # changing the public enqueue_request interface.
req_id = self.engine.enqueue_request(executor_request,
request.query_token_ids)
else:
req_id = self.engine.enqueue_request(executor_request)
return req_id
except Exception as e:
raise RequestError(str(e)) from e
def submit(self, request: GenerationRequest) -> GenerationResult:
""" Low-level API to the executor. Return a "future" GenerationResult which can be waited. """
self.start()
if self.rank != 0:
raise RuntimeError(
"Only rank 0 can submit requests.\n"
"To fix this, ensure that the llm.generate(...) method is "
"guarded with the `if __name__ == '__main__':` block.")
client_id = request.id if request.id is not None else self._get_next_client_id(
)
if request.id is None:
request.set_id(client_id)
logprob_params = self._get_logprob_params(request)
result = GenerationResult(
request,
background_error_handler=self._handle_background_error,
executor=self,
disaggregated_params=request.disaggregated_params,
logprob_params=logprob_params)
self._results[client_id] = result
request_id = self._enqueue_request(request)
# request_id returned from backend is necessary for the abort_request method.
self._client_id_to_request_id[client_id] = request_id
self._handle_background_error()
return result
def _pop_result(self, client_id: int):
self._results.pop(client_id, None)
self._client_id_to_request_id.pop(client_id, None)
def shutdown(self):
if self.doing_shutdown:
return
else:
self.doing_shutdown = True
print_colored_debug(f'Worker {mpi_rank()} shutdown...\n', "yellow")
if self.engine is not None:
if self.engine.can_enqueue_requests():
if self.await_response_thread.is_alive():
self.await_response_thread.stop()
self.await_response_thread.join()
if self.dispatch_stats_thread.is_alive():
self.dispatch_stats_thread.stop()
self.dispatch_stats_thread.join()
if self.dispatch_kv_cache_events_thread.is_alive():
self.dispatch_kv_cache_events_thread.stop()
self.dispatch_kv_cache_events_thread.join()
self.engine.shutdown()
self.engine = None
# Check if there are any errors from the threads before shutdown.
self._handle_background_error()
print_colored_debug(f"Worker {mpi_rank()} shutdown done.\n", "yellow")
def block_subordinates(self):
if self.rank != 0:
if isinstance(self.engine, tllm.Executor):
self.shutdown()
raise self.WorkerExit(
"block_subordinates() should be used in a `with GenerationExecutorWorker() as ...:` block"
)
from tensorrt_llm._torch.pyexecutor.py_executor import PyExecutor
if isinstance(self.engine, PyExecutor):
self.engine.wait_shutdown()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback) -> bool:
self.shutdown()
return exc_type is None or exc_type == GenerationExecutorWorker.WorkerExit
def __del__(self):
self.shutdown()
@print_traceback_on_error
def worker_main(
engine: Path | Engine,
worker_queues: WorkerCommIpcAddrs,
log_level: str,
executor_config: Optional[tllm.ExecutorConfig] = None,
batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
worker_cls: type = GenerationExecutorWorker,
tracer_init_kwargs: Optional[dict] = None,
_torch_model_class_mapping: Optional[dict] = None,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
ready_signal: Optional[str] = None,
is_llm_executor: Optional[
bool] = True, # whether it's the main executor instance
lora_config: Optional[LoraConfig] = None,
BASE_ZMQ_CLASSES: Dict = serialization.BASE_ZMQ_CLASSES,
) -> None:
    # The base classes for ZMQ serialization. Passed through from the parent process to ensure
    # that child processes include any classes added at runtime (such as those from `register_approved_ipc_class`).
serialization.BASE_ZMQ_CLASSES = BASE_ZMQ_CLASSES
mpi_comm().barrier()
print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
"green")
pid = os.getpid()
cpus = os.sched_getaffinity(pid)
if cpus:
        logger.warning(
            f"Worker process {pid} is bound to CPUs {cpus}; this may harm "
            "performance.")
        logger.warning("Clearing the CPU affinity.")
clear_sched_affinity(pid)
result_queue: Optional[IpcQueue] = None
result_queues: Optional[List[IpcQueue]] = None
postproc_worker_config = postproc_worker_config or PostprocWorkerConfig()
is_leader: bool = mpi_rank() == 0
if tracer_init_kwargs is not None and is_leader:
tracer = VizTracer(**tracer_init_kwargs)
tracer.register_exit()
tracer.start()
set_global_tracer(tracer)
if _torch_model_class_mapping is not None:
from tensorrt_llm._torch.models.modeling_auto import MODEL_CLASS_MAPPING
MODEL_CLASS_MAPPING.update(**_torch_model_class_mapping)
set_mpi_session_cpp(mpi_comm())
if is_leader:
        # Only set the log level for the leader process; the other processes
        # inherit the log level from the "TLLM_LOG_LEVEL" environment variable.
logger.set_level(log_level)
request_queue = IpcQueue(worker_queues.request_queue_addr,
is_server=False,
name="worker_request_queue")
request_error_queue = IpcQueue(worker_queues.request_error_queue_addr,
is_server=False,
name="worker_request_error_queue")
mp_stats_queue = FusedIpcQueue(worker_queues.stats_queue_addr,
is_server=False,
fuse_message=True,
name="worker_stats_queue")
kv_cache_events_queue = FusedIpcQueue(
worker_queues.kv_cache_events_queue_addr,
is_server=False,
fuse_message=False,
name="worker_kv_cache_events_queue")
if postproc_worker_config.enabled:
            # IPC queues for feeding inputs to the parallel postprocess
            # processes; each one is a PAIR ZMQ socket.
result_queues = [
FusedIpcQueue(is_server=True,
fuse_message=PERIODICAL_RESP_IN_AWAIT,
name=f"postprocess_{i}_feedin_queue")
for i in range(postproc_worker_config.num_postprocess_workers)
]
else:
            # IPC queue for sending results back to the proxy, letting the
            # proxy process handle the postprocessing.
result_queue = FusedIpcQueue(worker_queues.result_queue_addr,
is_server=False,
fuse_message=PERIODICAL_RESP_IN_AWAIT,
name="worker_result_queue")
def notify_proxy_threads_to_quit():
# Signal the dispatcher thread in the proxy to quit
if result_queue is not None:
result_queue.put(None)
else:
assert result_queues is not None
for q in result_queues:
q.put(None)
# Signal the stats thread in the proxy to quit
mp_stats_queue.put(None)
kv_cache_events_queue.put(None)
postprocess_worker_futures = []
if is_leader and postproc_worker_config.enabled:
print_colored_debug(f"initiate postprocess workers...", "yellow")
proxy_result_queue: tuple[
str, Optional[bytes]] = worker_queues.result_queue_addr
assert result_queues is not None
assert postproc_worker_config.postprocess_tokenizer_dir is not None
postproc_worker_pool = ProcessPoolExecutor(
max_workers=postproc_worker_config.num_postprocess_workers)
assert isinstance(proxy_result_queue, tuple)
for i in range(postproc_worker_config.num_postprocess_workers):
fut = postproc_worker_pool.submit(
postproc_worker_main, result_queues[i].address,
proxy_result_queue,
postproc_worker_config.postprocess_tokenizer_dir,
PostprocWorker.default_record_creator,
serialization.BASE_ZMQ_CLASSES)
postprocess_worker_futures.append(fut)
    # Error handling in the Worker/MPI process:
    #   1. During Executor initialization, errors are captured and
    #      sent back via request_error_queue.
    #   2. During execution, errors are captured by ManagedThreads:
    #      a) Per-request errors are sent back via the result_queue and
    #         eventually raised in handle_response() in the main thread.
    #      b) System errors are raised in the MPI process and handled by
    #         future.done_callback, which propagates the error to the
    #         error_queue in the main thread.
mpi_comm().barrier()
print_colored_debug(f"Worker {mpi_rank()} ready to setup backend...\n",
"green")
try:
worker: GenerationExecutorWorker = worker_cls(
engine,
executor_config,
batched_logits_processor,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
lora_config=lora_config)
except Exception as e:
logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
logger.error(traceback.format_exc())
print_colored_debug(f"error: {traceback.format_exc()}", "red")
if is_leader:
request_error_queue.put(e)
return
with worker:
try:
worker.block_subordinates()
if is_leader:
if postproc_worker_config.enabled:
worker.set_postproc_queues(result_queues)
else:
worker.set_result_queue(result_queue)
# initialize the iteration result queues
worker._set_iteration_result_queue(worker.stats_queues,
mp_stats_queue)
worker._set_iteration_result_queue(worker.kv_events_queues,
kv_cache_events_queue)
request_error_queue.put(ready_signal)
while (req := request_queue.get()) is not None:
if isinstance(req, CancellingRequest):
worker.abort_request(req.id)
elif isinstance(req, GenerationRequest):
try:
worker.submit(req)
request_error_queue.put(None) # None means success
except RequestError as e:
request_error_queue.put(e)
else:
raise ValueError(f"Unknown request type: {type(req)}")
notify_proxy_threads_to_quit()
except GenerationExecutorWorker.WorkerExit as e:
            # This will be caught by the with-statement and exit normally.
raise e
except Exception as e: # other critical errors
if is_leader:
notify_proxy_threads_to_quit()
err = Exception(f"Failed during generation: {e}")
logger.error(traceback.format_exc())
if is_leader:
request_error_queue.put(err)
class AwaitResponseHelper:
    ''' Multiple implementations of await_response, selected for performance. '''
class HandlerKind(enum.Enum):
unknown = 0
single_process_worker = 1
ipc_periodically = 2
ipc_batched = 3
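    # single_process_worker: results go straight into the local GenerationResult
    # queues (no proxy). ipc_periodically: responses are pushed into a
    # FusedIpcQueue whose background thread batches the IPC. ipc_batched: the
    # responses of one await_responses call are batched and sent explicitly.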
def __init__(self, worker: "GenerationExecutorWorker"):
# TODO: make worker weakref
self.worker = worker
self.handler_kind: AwaitResponseHelper.HandlerKind = AwaitResponseHelper.HandlerKind.unknown
        self.enable_postprocess_parallel = self.worker.enable_postprocess_parallel
def responses_handler(self, responses: List[tllm.Response]):
HandlerKind = AwaitResponseHelper.HandlerKind
if self.handler_kind is HandlerKind.unknown:
if not (self.worker.result_queue is not None
or self.worker.postproc_queues is not None):
print_colored_debug(
f"creating await_response helper for Worker\n",
color="yellow")
                # When GenerationExecutorWorker is used in the main process,
                # i.e., the single-process mode.
self.handler_kind = HandlerKind.single_process_worker
elif self.worker.result_queue is not None or self.worker.postproc_queues is not None:
            # The GenerationExecutorProxy is used.
print_colored_debug(f"creating await_response helper for IPC\n",
color="yellow")
if PERIODICAL_RESP_IN_AWAIT:
self.handler_kind = HandlerKind.ipc_periodically
else:
self.handler_kind = HandlerKind.ipc_batched
else:
raise NotImplementedError
match self.handler_kind:
case HandlerKind.single_process_worker:
return self.handle_for_worker(responses)
case HandlerKind.ipc_batched:
return self.handle_for_ipc_batched(responses)
case HandlerKind.ipc_periodically:
return self.handle_for_ipc_periodically(responses)
case _:
raise NotImplementedError
def __call__(self) -> bool:
''' This method should be called by a ManagedThread. '''
responses = self.worker.engine.await_responses(
timeout=datetime.timedelta(milliseconds=100))
        # Filter out None values, since _engine_response_callback may return None.
        responses = [
            rsp for rsp in map(self.worker._engine_response_callback, responses)
            if rsp
        ]
with nvtx_range_debug(f"await_response-{len(responses)}",
color="red",
category="Worker"):
self.responses_handler(responses)
return True
def handle_for_worker(self, responses: List[tllm.Response]) -> None:
''' Return the responses to asyncio.event_loop. '''
event_loop = None
async_queues = []
for response in responses:
assert response is not None
queue = self.worker.return_queue(response.client_id)
logprobs_result = _get_logprobs(self.worker, response,
self.worker._is_pytorch_backend)
if logprobs_result:
response = ResponseWrapper(response, logprobs_result)
            # For AsyncQueue.sync_q, we batch the event notifications to avoid
            # generating too many of them, so we put without waiting here.
if isinstance(queue, _SyncQueue):
global_tracer().log_instant("worker-rsp.put")
queue.put_nowait(response)
async_queues.append(queue)
# all the loops are identical
event_loop = event_loop or queue.loop
else:
queue.put(response)
if response.has_error() or response.result.is_final:
self.worker._pop_result(response.client_id)
# Notify the events in bulk for performance.
if async_queues:
_SyncQueue.notify_many(event_loop, async_queues)
def handle_for_ipc_periodically(self,
responses: List[tllm.Response]) -> None:
        ''' Return the responses to the proxy via IPC. This puts each response
        into a FusedIpcQueue, whose background thread batches them and performs
        the IPC periodically. '''
with nvtx_range_debug(f"handle_for_ipc_periodically-{len(responses)}",
color="red",
category="Worker"):
for response in responses:
if self.worker._has_background_error():
response = self.worker._create_error_response(response)
elif response.has_error():
response = ErrorResponse(response.client_id,
response.error_msg,
response.request_id)
else:
logprobs_result = _get_logprobs(
self.worker, response, self.worker._is_pytorch_backend)
if logprobs_result:
response = ResponseWrapper(response, logprobs_result)
# TODO: To verify the performance of using ZMQ instead of SharedMemory
# to send the logits tensor back to the Proxy process.
_send_rsp(self.worker, response)
def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None:
''' Perform the IPC in batch explicitly. '''
        postproc_batches = [
            []
            for _ in range(self.worker.postproc_config.num_postprocess_workers)
        ] if self.enable_postprocess_parallel else None
        rsp_batch = [] if not self.enable_postprocess_parallel else None
for response in responses:
if self.worker._has_background_error():
response = self.worker._create_error_response(response)
elif response.has_error():
# Convert to ErrorResponse, because tllm.Response cannot be
# serialized when it has error.
response = ErrorResponse(response.client_id, response.error_msg,
response.request_id)
else:
logprobs_result = _get_logprobs(self.worker, response,
self.worker._is_pytorch_backend)
if logprobs_result:
response = ResponseWrapper(response, logprobs_result)
_send_rsp(self.worker,
response,
postproc_batches=postproc_batches,
rsp_batch=rsp_batch)
if postproc_batches:
for wid, batch in enumerate(postproc_batches):
if batch:
self.worker.postproc_queues[wid].put(batch)
if rsp_batch:
self.worker.result_queue.put(rsp_batch)
def _get_params_for_first_rsp(
worker,
client_id) -> Tuple[Optional[SamplingParams], Optional[PostprocParams]]:
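    """Return (sampling_params, postproc_params) only for the first response of
    a request; subsequent calls return (None, None), since the postprocess
    workers need these parameters to be transmitted only once."""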
res = worker._results.get(client_id, None)
assert res is not None
if not res._params_transmitted:
res._params_transmitted = True
return res.sampling_params, res.postproc_params
return None, None
def _get_logprobs(worker,
response: tllm.Response,
is_pytorch_backend=False) -> Optional[LogProbsResult]:
"""Compute logprob and prompt logprob and clear out logits if applicable.
"""
if is_pytorch_backend:
# _get_logprobs() is a WAR for the TRT backend, where top-k logprobs are computed post runtime.
# In the PyTorch backend, logprobs are already computed during runtime if requested.
return None
logprobs_result = None
generation_result = worker._results.get(response.client_id, None)
if not generation_result:
return
logprob_params = getattr(generation_result, "_logprob_params", None)
if logprob_params:
logprobs_result = compute_logprobs(logprob_params.prompt_logprobs,
logprob_params.logprobs,
response.result.context_logits,
response.result.generation_logits,
response.result.output_token_ids[0])
if logprob_params.drop_context_logits:
response.clear_context_logits()
if logprob_params.drop_generation_logits:
response.clear_generation_logits()
if response.result.is_final:
generation_result.clear_logprob_params()
return logprobs_result
def _send_rsp(
worker,
response: Union[tllm.Response, ResponseWrapper, ErrorResponse],
postproc_batches: Optional[List[List["PostprocWorker.Input"]]] = None,
rsp_batch: Optional[List[tllm.Response]] = None):
# if postproc_batches is set, append to batch instead of putting to IpcQueue
if worker.result_queue is not None:
if rsp_batch is not None:
rsp_batch.append(response)
else:
worker.result_queue.put(response)
else:
sampling_params, postproc_params = _get_params_for_first_rsp(
worker, response.client_id)
inp = PostprocWorker.Input(
response,
            # sampling_params is necessary for creating fake GenerationResult
            # instances in the postproc processes, which are used for
            # incremental detokenization. It should be transmitted only once
            # per request.
sampling_params=sampling_params,
postproc_params=postproc_params,
streaming=worker._results.get(response.client_id, None)._streaming)
pid = response.client_id % worker.postproc_config.num_postprocess_workers
if not postproc_batches:
# Group the responses into buckets for the postprocessing steps.
# Bucketing is used instead of random dispatching because the
# incremental detokenization during postprocessing relies on the
# prior CompletionOutput of a given request.
worker.postproc_queues[pid].put(inp)
else:
postproc_batches[pid].append(inp)
    # Promptly remove finished GenerationResult instances, which may hold
    # considerable memory.
if is_llm_response(response):
if response.has_error() or response.result.is_final:
worker._pop_result(response.client_id)
elif isinstance(response, ErrorResponse):
worker._pop_result(response.client_id)
else:
raise ValueError(f"Unknown response type: {response}")