import asyncio
from pathlib import Path
from queue import Queue
from threading import Event
from typing import AsyncGenerator, Optional, Union

from tensorrt_llm._utils import mpi_comm
from tensorrt_llm.llmapi.utils import enable_llm_debug, logger_debug

from .._utils import mpi_rank
from ..bindings import executor as tllm
from ..builder import Engine
from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig
from ..llmapi.tokenizer import TokenizerBase
from ..logger import set_level
from ..lora_manager import LoraConfig
from ..sampling_params import BatchedLogitsProcessor
from .base_worker import BaseWorker
from .postproc_worker import PostprocWorkerConfig
from .request import GenerationRequest
from .rpc import RPCServer

class RpcWorker(BaseWorker):
    """
    An RPC wrapper for the BaseWorker class.

    Actions:
    - `setup_engine`: Set up the engine.
    - `submit`: Submit a request to the worker.
    - `fetch_responses`: Fetch the latest responses from the engine.
    - `fetch_stats`: Fetch the latest stats from the engine.
    - `fetch_kv_cache_events`: Fetch the latest KV cache events from the engine.
    - `shutdown`: Shut down the worker.
    """

    # Number of RPC server workers
    NUM_WORKERS = 6

    def __init__(
        self,
        engine: Union[Path, Engine],
        executor_config: Optional[tllm.ExecutorConfig] = None,
        is_llm_executor: Optional[bool] = None,
        lora_config: Optional[LoraConfig] = None,
        batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
        postproc_worker_config: Optional[PostprocWorkerConfig] = None,
        kv_connector_config: Optional[KvCacheConnectorConfig] = None,
        hf_model_dir: Optional[Path] = None,
        tokenizer: Optional[TokenizerBase] = None,
        llm_args: Optional[BaseLlmArgs] = None,
    ) -> None:
        super().__init__(
            engine=engine,
            executor_config=executor_config,
            is_llm_executor=is_llm_executor,
            lora_config=lora_config,
            llm_args=llm_args,
            batched_logits_processor=batched_logits_processor,
            postproc_worker_config=postproc_worker_config,
            kv_connector_config=kv_connector_config,
            hf_model_dir=hf_model_dir,
            tokenizer=tokenizer,
        )
        # Extract garbage_collection_gen0_threshold from llm_args if available
        self.garbage_collection_gen0_threshold = (
            llm_args.garbage_collection_gen0_threshold if llm_args is not None
            and hasattr(llm_args, 'garbage_collection_gen0_threshold') else
            None)
        self.shutdown_event = Event()

        # Engine responses are routed into a local queue, which
        # fetch_responses() drains.
        self._response_queue = Queue()
        self.set_result_queue(self._response_queue)

    def submit(self, request: GenerationRequest):
        """Submits a request to the worker."""
        super().submit(request)

    def fetch_responses(self, timeout: Optional[float] = None) -> list:
        logger_debug(f"RpcWorker {mpi_rank()} is fetching responses",
                     color="yellow")
        # NOTE: This is a blocking call; it waits until responses are available.
        responses = super().await_responses(timeout)
        self._await_response_helper.responses_handler(responses)

        qsize = self._response_queue.qsize()
        logger_debug(f"RpcWorker returning {qsize} responses", color="yellow")

        all_responses = []
        for _ in range(qsize):
            # The queue contains batches of responses, so extend the list
            all_responses.extend(self._response_queue.get())
        return all_responses

    async def fetch_responses_async(self,
                                    timeout: Optional[float] = None) -> list:
        # A truly asynchronous version of fetch_responses.
        logger_debug(f"RpcWorker {mpi_rank()} is fetching responses async",
                     color="yellow")

        # Await any pending responses in a worker thread so the event loop is
        # not blocked.
        responses = await asyncio.to_thread(self.fetch_responses,
                                            timeout=timeout)
        return responses

    async def fetch_stats_async(self, timeout: Optional[float] = None) -> list:
        return await asyncio.to_thread(self.fetch_stats)

    async def fetch_kv_cache_events_async(self,
                                          timeout: Optional[float] = None
                                          ) -> list:
        return await asyncio.to_thread(self.fetch_kv_cache_events)

    # For streaming performance.
    async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]:
        while not self.shutdown_event.is_set():
            responses = await self.fetch_responses_async()
            if responses:  # Only yield if there are actual responses
                logger_debug(
                    f"RpcWorker {mpi_rank()} is yielding responses: {responses}",
                    color="yellow")
                yield responses  # Batch the responses to optimize IPC performance
            else:
                # Small delay to prevent busy waiting when there are no responses
                await asyncio.sleep(0)
        logger_debug(
            f"RpcWorker {mpi_rank()} quitting fetch_responses_loop_async",
            color="yellow")

    async def _generic_fetch_loop_async(
            self,
            fetch_method,
            serializer,
            method_name: str,
            timeout: Optional[float] = None) -> AsyncGenerator[list, None]:
        """Generic method for fetching data in a loop.

        Args:
            fetch_method: The async method to call for fetching data
            serializer: The serializer function to apply to each item
            method_name: Name of the method, used for logging
            timeout: Optional timeout between fetches
        """
        while not self.shutdown_event.is_set():
            timeout = timeout or 0.1
            await asyncio.sleep(timeout)
            data = await fetch_method()
            # Always yield data, even if empty, so the client does not appear
            # to hang.
            # TODO: Skip empty payloads to reduce the IPC overhead.
            yield [serializer(item) for item in data]
        logger_debug(f"RpcWorker {mpi_rank()} quitting {method_name}",
                     color="yellow")

    async def fetch_stats_loop_async(
            self,
            timeout: Optional[float] = None) -> AsyncGenerator[list, None]:
        async for data in self._generic_fetch_loop_async(
                fetch_method=self.fetch_stats_async,
                serializer=self._stats_serializer,
                method_name="fetch_stats_loop_async",
                timeout=timeout):
            yield data

    async def fetch_kv_cache_events_loop_async(
            self,
            timeout: Optional[float] = None) -> AsyncGenerator[list, None]:
        async for data in self._generic_fetch_loop_async(
                fetch_method=self.fetch_kv_cache_events_async,
                serializer=self._kv_cache_events_serializer,
                method_name="fetch_kv_cache_events_loop_async",
                timeout=timeout):
            yield data

    def setup_engine(self):
        # Force all ranks to wait here so that they start creating the
        # executor simultaneously. Only call the barrier when there are
        # multiple ranks, to avoid hanging in single-process tests.
        if mpi_comm().Get_size() > 1:
            mpi_comm().barrier()

        super().setup_engine()

    def shutdown(self):
        logger_debug(f"RPC worker {mpi_rank()} is shutting down",
                     color="yellow")
        self.shutdown_event.set()
        super().shutdown()
        logger_debug(f"RPC worker {mpi_rank()} is shutdown", color="yellow")

    def start(self):
        # Intentionally a no-op: the engine is brought up explicitly via
        # `setup_engine` (see `main_task`).
        pass

    @staticmethod
    def main_task(
        engine: Union[Path, Engine],
        rpc_addr: str,
        *,
        executor_config: Optional[tllm.ExecutorConfig] = None,
        batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
        postproc_worker_config: Optional[PostprocWorkerConfig] = None,
        is_llm_executor: Optional[bool] = None,
        lora_config: Optional[LoraConfig] = None,
        llm_args: Optional[BaseLlmArgs] = None,
        kv_connector_config: Optional[KvCacheConnectorConfig] = None,
        hf_model_dir: Optional[Path] = None,
        tokenizer: Optional[TokenizerBase] = None,
        **kwargs,
    ) -> None:
        if enable_llm_debug():
            set_level("debug")

        # Step 1: Create the worker instance.
        worker = RpcWorker(
            engine=engine,
            executor_config=executor_config,
            is_llm_executor=is_llm_executor,
            lora_config=lora_config,
            llm_args=llm_args,
            batched_logits_processor=batched_logits_processor,
            postproc_worker_config=postproc_worker_config,
            kv_connector_config=kv_connector_config,
            hf_model_dir=hf_model_dir,
            tokenizer=tokenizer,
        )

        if mpi_rank() != 0:
            # Non-leader workers set up the engine immediately. The leader
            # worker waits for the RPC call so that any setup error can
            # propagate back to the client.
            logger_debug(f"Worker {mpi_rank()} is setting up the engine",
                         color="yellow")
            worker.setup_engine()

        else:
            logger_debug(f"Worker {mpi_rank()} is creating the RPC service",
                         color="yellow")
            # Step 2: Create the RPC service; it exposes all of the worker's
            # APIs as remote calls to the client. num_workers is set larger
            # than 1 because some streaming tasks run indefinitely, such as
            # await_responses_async.
            rpc_server = RPCServer(worker, num_workers=RpcWorker.NUM_WORKERS)
            rpc_server.bind(rpc_addr)
            rpc_server.start()

            # Step 3: Wait for the worker to shut down.
            logger_debug(
                f"Worker {mpi_rank()} is waiting for the worker to shutdown")
            worker.shutdown_event.wait()
            rpc_server.shutdown()
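
    # Context-manager support (a hedged usage note): `with RpcWorker(...) as
    # worker: ...` shuts the worker down on exit. Because __exit__ returns
    # True, exceptions raised inside the `with` block are suppressed after
    # shutdown.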
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.shutdown()
        return True