import json
import threading
from typing import List, Optional

from ..llmapi.mpi_session import MpiPoolSession, MpiSession
from ..llmapi.utils import logger_debug, print_colored
from ..logger import logger
from .executor import GenerationExecutor
from .postproc_worker import PostprocWorkerConfig
from .result import IterationResult
from .rpc_proxy_mixin import RpcExecutorMixin
from .rpc_worker import RpcWorker
from .utils import create_mpi_comm_session, get_spawn_proxy_process_env
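
# GenerationExecutorRpcProxy is the proxy-side executor: it launches RpcWorker
# processes through an MpiSession and talks to them over the RPC client/address
# initialized by init_rpc_executor(). Illustrative lifecycle (the worker_kwargs
# dict is a placeholder assembled by the caller):
#
#     proxy = GenerationExecutorRpcProxy(worker_kwargs={...}, model_world_size=2)
#     ...
#     proxy.shutdown()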


class GenerationExecutorRpcProxy(RpcExecutorMixin, GenerationExecutor):
    # NOTE: this is a global counter for the number of instances of this class
    INSTANCE_COUNTER = 0

    def __init__(
        self,
        worker_kwargs: dict,
        model_world_size: int = 1,
        mpi_session: Optional[MpiSession] = None,
        *,
        postproc_worker_config: Optional[PostprocWorkerConfig] = None,
        is_llm_executor: Optional[bool] = None,
    ):
        """
        Args:
            worker_kwargs: kwargs forwarded to the RPC worker
            model_world_size: the world size of the model
            mpi_session: the MPI session to use
            postproc_worker_config: the postprocessing worker config
            is_llm_executor: whether this is an LLM executor
        """
        GenerationExecutorRpcProxy.INSTANCE_COUNTER += 1
        self.init_rpc_executor()

        postproc_worker_config = postproc_worker_config or PostprocWorkerConfig()

        super().__init__(
            num_postprocess_workers=postproc_worker_config.num_postprocess_workers,
            postprocess_tokenizer_dir=postproc_worker_config.postprocess_tokenizer_dir,
            is_llm_executor=is_llm_executor,
        )

        self._create_mpi_session(model_world_size, mpi_session)

        # Inject the generated HMAC key into worker_kwargs for the workers
        worker_kwargs['hmac_key'] = self.hmac_key
        self.worker_kwargs = worker_kwargs

        self.launch_workers()

        # Invoke model creation on the remote.
        # TBD: move model creation to the MPI task, or leave it in RPC?
        self.setup_engine_remote()

        # Set up the main loop once the engine is ready
        self._setup_mainloop_with_tasks()

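    # Workers are launched by submitting RpcWorker.main_task to the MPI session;
    # each rank receives rpc_addr, which it uses to reach this proxy's RPC endpoint.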
    def launch_workers(self):
        logger.debug("Launching workers")
        assert self.mpi_session is not None
        self.mpi_session.submit(RpcWorker.main_task,
                                rpc_addr=self.rpc_addr,
                                **self.worker_kwargs)

    def _setup_mainloop_with_tasks(self):
        """Set up the main loop with the tasks the RPC proxy needs.

        Note: stats and kv_events are fetched on demand via direct RPC calls
        (get_stats, aget_stats, get_kv_events, aget_kv_events), not via streaming loops.
        """
        tasks = [
            self._fetch_responses_loop_async,
        ]
        # Call the mixin's setup_mainloop with custom tasks
        self.setup_mainloop(tasks=tasks, thread_name="rpc_proxy_main_loop")

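    # Both the stats and KV-event getters below issue a blocking RPC, deserialize
    # any JSON strings the worker returns, and degrade to an empty list on errors.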
    def get_stats(self, timeout: float) -> List[dict]:
        """Get iteration statistics from the runtime via RPC.

        Args:
            timeout (float): Max wait time in seconds for the RPC call.

        Returns:
            List[dict]: A list of runtime stats as dicts.
        """
        try:
            stats = self.rpc_client.fetch_stats_wait_async(
                timeout=timeout).remote()
            return [json.loads(s) if isinstance(s, str) else s for s in stats]
        except Exception as e:
            logger.debug(f"Error fetching stats via RPC: {e}")
            return []

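    # The async variants return an IterationResult: the fetched records are pushed
    # onto its queue here, and callers consume them with `async for`.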
    def aget_stats(self, timeout: float) -> IterationResult:
        """Get iteration statistics from the runtime via RPC (async).

        Args:
            timeout (float): Max wait time in seconds for the RPC call.

        Returns:
            IterationResult: An async iterable object containing runtime stats.
        """
        self._maybe_initialize_iteration_results()

        if self._iter_stats_result is None:
            print_colored("Iteration statistics are not available yet.\n",
                          "yellow")
            from .executor import empty_async_iterable
            return empty_async_iterable()

        # Fetch stats via RPC and populate the result
        try:
            stats = self.rpc_client.fetch_stats_wait_async(
                timeout=timeout).remote()
        except Exception:
            stats = []

        for stat in stats:
            self._iter_stats_result.queue.put(stat)

        self._iter_stats_result.set_timeout(timeout)
        return self._iter_stats_result

    def get_kv_events(self, timeout: float) -> List[dict]:
        """Get iteration KV events from the runtime via RPC.

        Args:
            timeout (float): Max wait time in seconds for the RPC call.

        Returns:
            List[dict]: A list of runtime events as dicts.
        """
        try:
            # Events are already serialized by the worker's fetch_kv_cache_events_wait_async()
            events = self.rpc_client.fetch_kv_cache_events_wait_async(
                timeout=timeout).remote()
            return [json.loads(e) if isinstance(e, str) else e for e in events]
        except Exception as e:
            logger.debug(f"Error fetching kv events via RPC: {e}")
            return []

    def aget_kv_events(self, timeout: float) -> IterationResult:
        """Get iteration KV events from the runtime via RPC (async).

        Args:
            timeout (float): Max wait time in seconds for the RPC call.

        Returns:
            IterationResult: An async iterable object containing runtime events.
        """
        # Initialize the iteration result if needed
        self._maybe_initialize_iteration_results()

        if self._iter_kv_events_result is None:
            from .executor import empty_async_iterable
            return empty_async_iterable()

        # Fetch kv events via RPC and populate the result
        try:
            events = self.rpc_client.fetch_kv_cache_events_wait_async(
                timeout=timeout).remote()
        except Exception:
            events = []

        for event in events:
            self._iter_kv_events_result.queue.put(event)

        self._iter_kv_events_result.set_timeout(timeout)
        return self._iter_kv_events_result

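    # Thin RPC wrappers: setup_engine_remote() blocks until the remote replies
    # (need_response=True), while shutdown_remote() is fire-and-forget
    # (need_response=False).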
    def setup_engine_remote(self):
        return self.rpc_client.setup_engine().remote(need_response=True)

    def shutdown_remote(self):
        logger_debug("Shutting down rpc remote", color="yellow")
        self.rpc_client.shutdown().remote(need_response=False)

    def abort_request(self, request_id: int) -> None:
        return self.rpc_client.abort_request(request_id).remote()

    def shutdown(self):
        if self._shutdown_event.is_set():
            return
        self._shutdown_event.set()
        logger_debug("Shutting down GenerationExecutorRpcProxy",
                     color="yellow")

        # 1. Shut down the RPC server (PyExecutor rank 0 + RPC server)
        self.shutdown_remote()

        # 2. Stop the main loop so that no new RPC requests are issued
        if self.main_loop and self.main_loop_task_obj:
            logger_debug("Cancelling main loop task.", color="yellow")
            # cancel() is thread-safe
            try:
                self.main_loop.call_soon_threadsafe(
                    self.main_loop_task_obj.cancel)
            except Exception as e:
                logger_debug(f"Error cancelling main loop task: {e}",
                             color="yellow")

        # Only join if we're not calling from the main_loop_thread itself
        # (e.g., during garbage collection in that thread)
        if self.main_loop_thread and threading.current_thread(
        ) != self.main_loop_thread:
            self.main_loop_thread.join(timeout=2.0)
            if self.main_loop_thread.is_alive():
                logger.warning("Main loop thread did not exit gracefully")

        # 3. Shut down the MPI session; this should wait until all the PyExecutor
        #    processes are shut down
        if self.mpi_session is not None:
            logger_debug("Shutting down mpi session", color="yellow")
            self.mpi_session.shutdown()
            logger_debug("Mpi session shutdown", color="yellow")
            self.mpi_session = None

        self.rpc_client.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.shutdown()

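    # MPI session selection: reuse an externally provided session if given;
    # otherwise attach to pre-spawned MPI processes via a comm session, or fall
    # back to spawning an MpiPoolSession of model_world_size workers.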
    def _create_mpi_session(self, model_world_size: int,
                            mpi_session: Optional[MpiSession]):
        mpi_process_pre_spawned: bool = get_spawn_proxy_process_env()
        if mpi_session is None:
            if mpi_process_pre_spawned:
                logger_debug('[proxy] create comm session ...\n', "yellow")
                self.mpi_session = create_mpi_comm_session(model_world_size)
            else:
                logger_debug('[proxy] create pool session ...\n', "yellow")
                self.mpi_session = MpiPoolSession(n_workers=model_world_size)
        else:
            logger_debug('[proxy] using external mpi session ...\n', "yellow")
            self.mpi_session = mpi_session