# tensorrt_llm/executor/rpc_worker.py

import asyncio
from pathlib import Path
from queue import Queue
from threading import Event
from typing import AsyncGenerator, Optional, Union

from tensorrt_llm._utils import mpi_comm
from tensorrt_llm.llmapi.utils import enable_llm_debug, logger_debug

from .._utils import mpi_rank
from ..bindings import executor as tllm
from ..builder import Engine
from ..llmapi.llm_args import BaseLlmArgs
from ..llmapi.tokenizer import TokenizerBase
from ..logger import set_level
from ..sampling_params import BatchedLogitsProcessor
from .base_worker import BaseWorker
from .postproc_worker import PostprocWorkerConfig
from .request import GenerationRequest
from .rpc import RPCServer


class RpcWorker(BaseWorker):
"""
A RPC wrapper for the BaseWorker class.
Actions:
- `setup_engine`: Setup the engine.
- `submit`: Submit a request to the worker.
- `fetch_responses`: Fetch the latest responses from engine.
- `fetch_stats`: Fetch the latest stats from engine.
- `fetch_kv_cache_events`: Fetch the latest kv cache events from engine.
- `shutdown`: Shutdown the worker.
"""
# Number of RPC server workers
NUM_WORKERS = 6
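    # NOTE: kept greater than 1 because some of the exposed calls (e.g.
    # `fetch_responses_loop_async`) stream indefinitely and would otherwise
    # occupy an RPC worker for the lifetime of the server (see `main_task`).
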
def __init__(
self,
engine: Union[Path, Engine],
executor_config: Optional[tllm.ExecutorConfig] = None,
is_llm_executor: Optional[bool] = None,
batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[BaseLlmArgs] = None,
) -> None:
super().__init__(
engine=engine,
executor_config=executor_config,
is_llm_executor=is_llm_executor,
llm_args=llm_args,
batched_logits_processor=batched_logits_processor,
postproc_worker_config=postproc_worker_config,
hf_model_dir=hf_model_dir,
tokenizer=tokenizer,
)
# Extract garbage_collection_gen0_threshold from llm_args if available
self.garbage_collection_gen0_threshold = (
llm_args.garbage_collection_gen0_threshold if llm_args is not None
and hasattr(llm_args, 'garbage_collection_gen0_threshold') else
None)
self.shutdown_event = Event()
self._response_queue = Queue()
        self.set_result_queue(self._response_queue)

    def submit(self, request: GenerationRequest):
""" Submits a request to the worker. """
        super().submit(request)

    def fetch_responses(self, timeout: Optional[float] = None) -> list:
logger_debug(f"RpcWorker {mpi_rank()} is fetching responses",
color="yellow")
        # NOTE: This is a blocking call; it waits until responses are available.
responses = super().await_responses(timeout)
self._await_response_helper.responses_handler(responses)
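        # `responses_handler` routes the handled responses into the result
        # queue registered in `__init__` (self._response_queue); they are
        # drained in batches below.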
qsize = self._response_queue.qsize()
logger_debug(f"RpcWorker returning {qsize} responses", color="yellow")
all_responses = []
for _ in range(qsize):
# The queue contains batches of responses, so extend the list
all_responses.extend(self._response_queue.get())
        return all_responses

    async def fetch_responses_async(self,
                                    timeout: Optional[float] = None) -> list:
        # A truly asynchronous version of `fetch_responses`.
logger_debug(f"RpcWorker {mpi_rank()} is fetching responses async",
color="yellow")
        # Run the blocking fetch in a worker thread so the event loop stays free.
responses = await asyncio.to_thread(self.fetch_responses,
timeout=timeout)
        return responses

    async def fetch_stats_async(self, timeout: Optional[float] = None) -> list:
        return await asyncio.to_thread(self.fetch_stats)

async def fetch_kv_cache_events_async(self,
timeout: Optional[float] = None
) -> list:
return await asyncio.to_thread(self.fetch_kv_cache_events)
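
    # The `*_async` wrappers above offload the blocking fetch calls to a
    # worker thread via `asyncio.to_thread`, so the RPC server's event loop
    # stays responsive while the engine is being polled.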
    # Streaming loop: yields batches of responses to improve streaming performance.
async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]:
while not self.shutdown_event.is_set():
responses = await self.fetch_responses_async()
if responses: # Only yield if there are actual responses
logger_debug(
f"RpcWorker {mpi_rank()} is yielding responses: {responses}",
color="yellow")
                yield responses  # batch the responses to optimize IPC performance
else:
                # Yield control to the event loop to avoid busy-waiting when
                # there are no responses.
                await asyncio.sleep(0)
        logger_debug(
            f"RpcWorker {mpi_rank()} quitting fetch_responses_loop_async",
            color="yellow")

    async def _generic_fetch_loop_async(
self,
fetch_method,
serializer,
method_name: str,
timeout: Optional[float] = None) -> AsyncGenerator[list, None]:
"""Generic method for fetching data in a loop.
Args:
fetch_method: The async method to call for fetching data
serializer: The serializer function to apply to each item
method_name: Name of the method for logging
timeout: Optional timeout between fetches
"""
while not self.shutdown_event.is_set():
timeout = timeout or 0.1
await asyncio.sleep(timeout)
data = await fetch_method()
            # Always yield data, even if empty, to keep the client from appearing to hang.
# TODO: Remove the empty data to reduce the IPC overhead
yield [serializer(item) for item in data]
logger_debug(f"RpcWorker {mpi_rank()} quitting {method_name}",
color="yellow")
async def fetch_stats_loop_async(
self,
timeout: Optional[float] = None) -> AsyncGenerator[list, None]:
async for data in self._generic_fetch_loop_async(
fetch_method=self.fetch_stats_async,
serializer=self._stats_serializer,
method_name="fetch_stats_loop_async",
timeout=timeout):
            yield data

    async def fetch_kv_cache_events_loop_async(
self,
timeout: Optional[float] = None) -> AsyncGenerator[list, None]:
async for data in self._generic_fetch_loop_async(
fetch_method=self.fetch_kv_cache_events_async,
serializer=self._kv_cache_events_serializer,
method_name="fetch_kv_cache_events_loop_async",
timeout=timeout):
            yield data

    def setup_engine(self):
# Force all the ranks to wait here, and start creating the executor simultaneously.
# Only call barrier if we have multiple ranks to avoid hanging in single-process tests
if mpi_comm().Get_size() > 1:
mpi_comm().barrier()
        super().setup_engine()

    def shutdown(self):
logger_debug(f"RPC worker {mpi_rank()} is shutting down",
color="yellow")
self.shutdown_event.set()
super().shutdown()
logger_debug(f"RPC worker {mpi_rank()} is shutdown", color="yellow")
def start(self):
        pass

    @staticmethod
    def main_task(
engine: Union[Path, Engine],
rpc_addr: str,
*,
executor_config: Optional[tllm.ExecutorConfig] = None,
batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
llm_args: Optional[BaseLlmArgs] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
**kwargs,
) -> None:
if enable_llm_debug():
set_level("debug")
# Step 1: Create the worker instance
worker = RpcWorker(
engine=engine,
executor_config=executor_config,
is_llm_executor=is_llm_executor,
llm_args=llm_args,
batched_logits_processor=batched_logits_processor,
postproc_worker_config=postproc_worker_config,
hf_model_dir=hf_model_dir,
tokenizer=tokenizer,
)
if mpi_rank() != 0:
            # Non-leader workers set up the engine immediately. The leader
            # worker waits for the RPC call so that any error can be
            # propagated back to the client.
logger_debug(f"Worker {mpi_rank()} is setting up the engine",
color="yellow")
worker.setup_engine()
else:
logger_debug(f"Worker {mpi_rank()} is creating the RPC service",
color="yellow")
            # Step 2: Create the RPC service; it exposes all of the worker's APIs
            # as remote calls to the client.
            # Set num_workers larger than 1, since some streaming tasks, such as
            # await_responses_async, run indefinitely.
rpc_server = RPCServer(worker, num_workers=RpcWorker.NUM_WORKERS)
rpc_server.bind(rpc_addr)
rpc_server.start()
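            # `start()` returns after launching the server in the background;
            # this rank then blocks on the shutdown event below.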
            # Step 3: Wait for the worker to shut down
            logger_debug(
                f"Worker {mpi_rank()} is waiting for the shutdown event")
worker.shutdown_event.wait()
            rpc_server.shutdown()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
self.shutdown()
return True
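
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module): how a launcher might drive
# `RpcWorker.main_task` on every MPI rank. The address and the `launch_rank`
# helper below are hypothetical; they only show the intended call shape, with
# rank 0 serving RPC and the other ranks setting up the engine directly.
#
#     from tensorrt_llm.executor.rpc_worker import RpcWorker
#
#     def launch_rank(engine_path, rpc_addr, llm_args):
#         RpcWorker.main_task(
#             engine=engine_path,
#             rpc_addr=rpc_addr,  # e.g. an IPC/TCP address (hypothetical)
#             llm_args=llm_args,
#             is_llm_executor=True,
#         )
# ---------------------------------------------------------------------------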