# tensorrt_llm/executor/worker.py
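"""Worker-side GenerationExecutor: runs inside spawned MPI ranks and exchanges
requests and results with the proxy process over IPC queues."""
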
import gc
import os
import traceback
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import List, Optional, Union

import zmq

from tensorrt_llm.logger import logger
from .._utils import mpi_comm, mpi_rank
from ..bindings import executor as tllm
from ..builder import Engine
from ..llmapi.llm_args import BaseLlmArgs
from ..llmapi.mpi_session import set_mpi_session_cpp
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import VizTracer, set_global_tracer
from ..llmapi.utils import ManagedThread, logger_debug, print_traceback_on_error
from ..sampling_params import BatchedLogitsProcessor
from .base_worker import BaseWorker, _init_hf_modules
from .ipc import FusedIpcQueue, IpcQueue
from .postproc_worker import (PostprocWorker, PostprocWorkerConfig,
postproc_worker_main)
from .request import CancellingRequest, GenerationRequest
from .rpc_worker_mixin import RpcWorkerMixin
from .utils import ErrorResponse, RequestError, WorkerCommIpcAddrs

__all__ = [
    "GenerationExecutorWorker",
]


class GenerationExecutorWorker(RpcWorkerMixin, BaseWorker):
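    """Worker-process side of the generation executor.

    Runs inside an MPI rank: it pulls generation requests from IPC queues,
    submits them to the underlying engine, and streams responses back to
    the proxy. When ``rpc_addr`` is given, stats and KV-cache events are
    served on demand over RPC (see RpcWorkerMixin).

    Minimal usage sketch (the real wiring lives in ``worker_main`` below;
    ``request`` is a GenerationRequest):

        with GenerationExecutorWorker(engine, llm_args=llm_args) as worker:
            worker.start()
            worker.submit(request)
    """
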
def __init__(
self,
engine: Union[Path, Engine],
executor_config: Optional[tllm.ExecutorConfig] = None,
batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[BaseLlmArgs] = None,
rpc_addr: Optional[str] = None,
hmac_key: Optional[bytes] = None,
) -> None:
super().__init__(
engine=engine,
executor_config=executor_config,
batched_logits_processor=batched_logits_processor,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
hf_model_dir=hf_model_dir,
tokenizer=tokenizer,
llm_args=llm_args,
)
self.setup_engine()
        # Set up the RPC server for stats (skipping init_rpc_worker keeps the
        # IPC response queue). Only done when rpc_addr is provided.
if rpc_addr is not None:
self.rpc_addr = rpc_addr
self.hmac_key = hmac_key
self.start_rpc_server() # Reuse from RpcWorkerMixin
self.await_response_thread = ManagedThread(
self.await_response_task,
error_queue=self._error_queue,
name="await_response_thread")

    def start_thread(self, thread: ManagedThread):
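        # Only ranks that can enqueue requests run the background threads,
        # and a thread that is already alive must not be started twice.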
if self.engine.can_enqueue_requests() and not thread.is_alive():
thread.start()

    def await_response_task(self) -> bool:
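        # The boolean return value tells the owning ManagedThread whether to
        # keep looping.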
return self._await_response_helper()

    def start(self):
# Stats and KV events are now fetched on-demand via RPC,
# so we only need to start the response thread
self.start_thread(self.await_response_thread)

    def shutdown(self):
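        """Idempotent shutdown: stop the response thread, shut the engine
        down, and release any checkpoint-loader resources."""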
if self.doing_shutdown:
return
else:
self.doing_shutdown = True
logger_debug(f'Worker {mpi_rank()} shutdown...\n', "yellow")
if self.engine is not None:
if self.engine.can_enqueue_requests():
if self.await_response_thread.is_alive():
self.await_response_thread.stop()
self.await_response_thread.join()
self.engine.shutdown()
self.engine = None
if self.llm_args is not None:
assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined."
if (self.llm_args.backend == "pytorch"
and hasattr(self, "checkpoint_loader")
and self.checkpoint_loader is not None):
self.checkpoint_loader.cleanup()
self.checkpoint_loader = None
else:
if hasattr(
self._executor_config, "checkpoint_loader"
) and self._executor_config.checkpoint_loader is not None:
self._executor_config.checkpoint_loader.cleanup()
self._executor_config.checkpoint_loader = None
# Check if there are any errors from the threads before shutdown.
self._handle_background_error()
logger_debug(f"Worker {mpi_rank()} shutdown done.\n", "yellow")

    def block_subordinates(self):
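        """On non-leader ranks, either exit via WorkerExit (C++ executor) or
        wait for the PyTorch engine to shut down; no-op on the leader."""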
if self.rank != 0:
if isinstance(self.engine, tllm.Executor):
self.shutdown()
raise self.WorkerExit(
"block_subordinates() should be used in a `with GenerationExecutorWorker() as ...:` block"
)
from tensorrt_llm._torch.pyexecutor.py_executor import PyExecutor
if isinstance(self.engine, PyExecutor):
self.engine.wait_shutdown()


@print_traceback_on_error
def worker_main(
engine: Path | Engine,
worker_queues: WorkerCommIpcAddrs,
log_level: str,
executor_config: Optional[tllm.ExecutorConfig] = None,
batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
worker_cls: type = GenerationExecutorWorker,
tracer_init_kwargs: Optional[dict] = None,
_torch_model_class_mapping: Optional[dict] = None,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
ready_signal: Optional[str] = None,
is_llm_executor: Optional[
bool] = True, # whether it's the main executor instance
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[BaseLlmArgs] = None,
rpc_addr: Optional[str] = None,
hmac_key: Optional[bytes] = None,
) -> None:
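    """Entry point executed by every spawned MPI rank.

    Connects the IPC queues to the proxy, starts postprocess workers on the
    leader rank when enabled, constructs the worker, and (on the leader)
    dispatches incoming requests until a ``None`` sentinel arrives.
    """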
mpi_comm().barrier()
if llm_args is not None and llm_args.env_overrides:
        # This is needed because MPI_Init appears to cache the environment at
        # import time, and that cached snapshot is used to spawn workers. Env
        # overrides applied to the main process after tensorrt_llm is imported
        # are therefore not reflected in the spawned worker processes unless
        # we update the environment explicitly here.
os.environ.update(llm_args.env_overrides)
if llm_args is not None and llm_args.trust_remote_code:
_init_hf_modules()
logger_debug(f"Worker {mpi_rank()} entering worker_main...\n", "green")
result_queue: Optional[IpcQueue] = None
result_queues: Optional[List[IpcQueue]] = None
postproc_worker_config = postproc_worker_config or PostprocWorkerConfig()
is_leader: bool = mpi_rank() == 0
if tracer_init_kwargs is not None and is_leader:
tracer = VizTracer(**tracer_init_kwargs)
tracer.register_exit()
tracer.start()
set_global_tracer(tracer)
if _torch_model_class_mapping is not None:
from tensorrt_llm._torch.models.modeling_auto import MODEL_CLASS_MAPPING
MODEL_CLASS_MAPPING.update(**_torch_model_class_mapping)
set_mpi_session_cpp(mpi_comm())
if is_leader:
        # Only set the log level for the leader process; the other processes
        # inherit it from the "TLLM_LOG_LEVEL" environment variable.
logger.set_level(log_level)
request_queue = IpcQueue(worker_queues.request_queue_addr,
is_server=False,
name="worker_request_queue")
worker_init_status_queue = IpcQueue(
worker_queues.worker_init_status_queue_addr,
is_server=False,
socket_type=zmq.DEALER,
name="worker_init_status_queue")
if postproc_worker_config.enabled:
            # IPC queues for sending inputs to the parallel postprocess
            # processes; each one is a PAIR zmq socket.
result_queues = [
FusedIpcQueue(is_server=True,
fuse_message=False,
name=f"postprocess_{i}_feedin_queue")
for i in range(postproc_worker_config.num_postprocess_workers)
]
else:
            # IPC queue for sending results back to the proxy, letting the
            # proxy process handle the postprocessing.
result_queue = FusedIpcQueue(worker_queues.result_queue_addr,
is_server=False,
fuse_message=False,
name="worker_result_queue")

    def notify_proxy_threads_to_quit():
# Signal the dispatcher thread in the proxy to quit
if result_queue is not None:
result_queue.put(None)
else:
assert result_queues is not None
for q in result_queues:
q.put(None)
postprocess_worker_futures = []
if is_leader and postproc_worker_config.enabled:
logger_debug(f"initiate postprocess workers...", "yellow")
proxy_result_queue: tuple[
str, Optional[bytes]] = worker_queues.result_queue_addr
assert result_queues is not None
postproc_worker_pool = ProcessPoolExecutor(
max_workers=postproc_worker_config.num_postprocess_workers)
assert isinstance(proxy_result_queue, tuple)
for i in range(postproc_worker_config.num_postprocess_workers):
fut = postproc_worker_pool.submit(
postproc_worker_main,
result_queues[i].address,
proxy_result_queue,
postproc_worker_config.postprocess_tokenizer_dir,
PostprocWorker.default_record_creator,
)
postprocess_worker_futures.append(fut)

    # Error handling in the Worker/MPI process:
    # 1. During Executor initialization, errors are captured and sent back
    #    via the worker_init_status_queue.
    # 2. During execution, errors are captured by ManagedThreads:
    #    a) Per-request errors are sent back via the result_queue and are
    #       eventually raised in handle_response() in the main thread.
    #    b) System errors are raised in the MPI process and handled by
    #       future.done_callback, which propagates the error to the
    #       error_queue in the main thread.
mpi_comm().barrier()
logger_debug(f"Worker {mpi_rank()} ready to setup backend...\n", "green")
try:
worker: GenerationExecutorWorker = worker_cls(
engine,
executor_config,
batched_logits_processor,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
hf_model_dir=hf_model_dir,
tokenizer=tokenizer,
llm_args=llm_args,
rpc_addr=rpc_addr,
hmac_key=hmac_key)
except Exception as e:
logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
logger.error(traceback.format_exc())
logger_debug(f"error: {traceback.format_exc()}", "red")
if is_leader:
# Send error message with confirmation
error_msg = (e, traceback.format_exc())
if not worker_init_status_queue.notify_with_retry(error_msg):
logger.error("Failed to deliver error message to proxy")
return
# Optionally disable GC (default: not disabled)
if os.getenv("TRTLLM_WORKER_DISABLE_GC", "0") == "1":
gc.disable()
with worker:
try:
worker.block_subordinates()
if is_leader:
if postproc_worker_config.enabled:
worker.set_postproc_queues(result_queues)
else:
worker.set_result_queue(result_queue)
# Send ready signal with confirmation
ready_msg = (ready_signal, None)
if not worker_init_status_queue.notify_with_retry(ready_msg):
logger.warning(
"Failed to deliver ready signal to proxy, continuing anyway"
)
while (req := request_queue.get()) is not None:
if isinstance(req, CancellingRequest):
worker.abort_request(req.id)
elif isinstance(req, GenerationRequest):
try:
worker.submit(req)
except RequestError as e:
logger.error(f"submit request failed: {e}")
logger.error(traceback.format_exc())
worker._await_response_helper.temp_error_responses.put(
ErrorResponse(req.id, e, req.id))
else:
raise ValueError(f"Unknown request type: {type(req)}")
notify_proxy_threads_to_quit()
except GenerationExecutorWorker.WorkerExit as e:
            # This will be captured by the with-statement and exit normally.
raise e
except Exception as e: # other critical errors
if is_leader:
notify_proxy_threads_to_quit()
logger.error(traceback.format_exc())
# This will be captured by mpi4py and handled by future.done_callback
raise e