TensorRT-LLMs/tensorrt_llm/executor/rpc_proxy.py
Yan Chunwei ea6cd76c55
[None][refactor] simplify get_stats and get_kvcache_events with rpc (#9980)
Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
2025-12-22 18:23:43 +08:00

247 lines
9.0 KiB
Python

import json
import threading
from typing import List, Optional
from ..llmapi.mpi_session import MpiPoolSession, MpiSession
from ..llmapi.utils import logger_debug, print_colored
from ..logger import logger
from .executor import GenerationExecutor
from .postproc_worker import PostprocWorkerConfig
from .result import IterationResult
from .rpc_proxy_mixin import RpcExecutorMixin
from .rpc_worker import RpcWorker
from .utils import create_mpi_comm_session, get_spawn_proxy_process_env
class GenerationExecutorRpcProxy(RpcExecutorMixin, GenerationExecutor):
# NOTE: this is a global counter for the number of instances of this class
INSTANCE_COUNTER = 0
def __init__(
self,
worker_kwargs: dict,
model_world_size: int = 1,
mpi_session: Optional[MpiSession] = None,
*,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
):
"""
Args:
worker_kwargs: kwargs for the rpc worker
model_world_size: the world size of the model
mpi_session: the mpi session to use
postproc_worker_config: the postproc worker config
is_llm_executor: whether this is an llm executor
"""
GenerationExecutorRpcProxy.INSTANCE_COUNTER += 1
self.init_rpc_executor()
postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
)
super().__init__(
num_postprocess_workers=postproc_worker_config.
num_postprocess_workers,
postprocess_tokenizer_dir=postproc_worker_config.
postprocess_tokenizer_dir,
is_llm_executor=is_llm_executor,
)
self._create_mpi_session(model_world_size, mpi_session)
# Inject the generated HMAC key into worker_kwargs for workers
worker_kwargs['hmac_key'] = self.hmac_key
self.worker_kwargs = worker_kwargs
self.launch_workers()
# Invoke model creation on the remote
# TBD: Move model creation to the mpi task, or left in RPC?
self.setup_engine_remote()
# Setup main loop after engine is ready
self._setup_mainloop_with_tasks()
def launch_workers(self):
logger.debug(f"Launching workers")
assert self.mpi_session is not None
self.mpi_session.submit(RpcWorker.main_task,
rpc_addr=self.rpc_addr,
**self.worker_kwargs)
def _setup_mainloop_with_tasks(self):
"""Setup mainloop with tasks needed for RpcProxy.
Note: Stats and kv_events are now fetched on-demand via direct RPC calls
(get_stats, aget_stats, get_kv_events, aget_kv_events), not via streaming loops.
"""
tasks = [
self._fetch_responses_loop_async,
]
# Call mixin's setup_mainloop with custom tasks
self.setup_mainloop(tasks=tasks, thread_name="rpc_proxy_main_loop")
def get_stats(self, timeout: float) -> List[dict]:
"""Get iteration statistics from the runtime via RPC.
Args:
timeout (float): Max wait time in seconds for the RPC call.
Returns:
List[dict]: A list of runtime stats as dict.
"""
try:
stats = self.rpc_client.fetch_stats_wait_async(
timeout=timeout).remote()
return [json.loads(s) if isinstance(s, str) else s for s in stats]
except Exception as e:
logger.debug(f"Error fetching stats via RPC: {e}")
return []
def aget_stats(self, timeout: float) -> IterationResult:
"""Get iteration statistics from the runtime via RPC (async).
Args:
timeout (float): Max wait time in seconds for the RPC call.
Returns:
IterationResult: An async iterable object containing runtime stats.
"""
self._maybe_initialize_iteration_results()
if self._iter_stats_result is None:
print_colored("Iteration statistics are not available yet.\n",
"yellow")
from .executor import empty_async_iterable
return empty_async_iterable()
# Fetch stats via RPC and populate the result
try:
stats = self.rpc_client.fetch_stats_wait_async(
timeout=timeout).remote()
except Exception:
stats = []
for stat in stats:
self._iter_stats_result.queue.put(stat)
self._iter_stats_result.set_timeout(timeout)
return self._iter_stats_result
def get_kv_events(self, timeout: float) -> List[dict]:
"""Get iteration KV events from the runtime via RPC.
Args:
timeout (float): Max wait time in seconds for the RPC call.
Returns:
List[dict]: A list of runtime events as dict.
"""
try:
# Events are already serialized by the worker's fetch_kv_cache_events_wait_async()
events = self.rpc_client.fetch_kv_cache_events_wait_async(
timeout=timeout).remote()
return [json.loads(e) if isinstance(e, str) else e for e in events]
except Exception as e:
logger.debug(f"Error fetching kv events via RPC: {e}")
return []
def aget_kv_events(self, timeout: float) -> IterationResult:
"""Get iteration KV events from the runtime via RPC (async).
Args:
timeout (float): Max wait time in seconds for the RPC call.
Returns:
IterationResult: An async iterable object containing runtime events.
"""
# Initialize iteration result if needed
self._maybe_initialize_iteration_results()
if self._iter_kv_events_result is None:
from .executor import empty_async_iterable
return empty_async_iterable()
# Fetch kv events via RPC and populate the result
try:
events = self.rpc_client.fetch_kv_cache_events_wait_async(
timeout=timeout).remote()
except Exception:
events = []
for event in events:
self._iter_kv_events_result.queue.put(event)
self._iter_kv_events_result.set_timeout(timeout)
return self._iter_kv_events_result
def setup_engine_remote(self):
return self.rpc_client.setup_engine().remote(need_response=True)
def shutdown_remote(self):
logger_debug(f"Shutting down rpc remote", color="yellow")
self.rpc_client.shutdown().remote(need_response=False)
def abort_request(self, request_id: int) -> None:
return self.rpc_client.abort_request(request_id).remote()
def shutdown(self):
if self._shutdown_event.is_set():
return
self._shutdown_event.set()
logger_debug(f"Shutting down GenerationExecutorRpcProxy",
color="yellow")
# 1. shutdown the rpc server (PyExecutor Rank 0 + RPC server)
self.shutdown_remote()
# 2. stop the main loop, so that no new rpc requests
if self.main_loop and self.main_loop_task_obj:
logger_debug("Cancelling main loop task.", color="yellow")
# The cancel() is thread-safe
try:
self.main_loop.call_soon_threadsafe(
self.main_loop_task_obj.cancel)
except Exception as e:
logger_debug(f"Error cancelling main loop task: {e}",
color="yellow")
# Only join if we're not calling from the main_loop_thread itself
# (e.g., during garbage collection in that thread)
if self.main_loop_thread and threading.current_thread(
) != self.main_loop_thread:
self.main_loop_thread.join(timeout=2.0)
if self.main_loop_thread.is_alive():
logger.warning("Main loop thread did not exit gracefully")
# 3. shutdown the mpi session, this should wait until all the PyExecutor
# processes are shutdown
if self.mpi_session is not None:
logger_debug(f"Shutting down mpi session", color="yellow")
self.mpi_session.shutdown()
logger_debug(f"Mpi session shutdown", color="yellow")
self.mpi_session = None
self.rpc_client.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.shutdown()
def _create_mpi_session(self, model_world_size: int,
mpi_session: Optional[MpiSession]):
mpi_process_pre_spawned: bool = get_spawn_proxy_process_env()
if mpi_session is None:
if mpi_process_pre_spawned:
logger_debug('[proxy] create comm session ...\n', "yellow")
self.mpi_session = create_mpi_comm_session(model_world_size)
else:
logger_debug('[proxy] create pool session ...\n', "yellow")
self.mpi_session = MpiPoolSession(n_workers=model_world_size)
else:
logger_debug('[proxy] using external mpi session ...\n', "yellow")
self.mpi_session = mpi_session