From c5e803ba48bea76292d23cf5a99f81924b9e32d8 Mon Sep 17 00:00:00 2001
From: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
Date: Thu, 10 Apr 2025 21:57:06 +0800
Subject: [PATCH] chore: code cleanup for error logging and SharedMemory in
 proxy.py (#3432)

* cleanup log

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

* remove shared-memory

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

* remove ExecutorResponse

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

* add assert for postproc

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

---------

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
---
 tensorrt_llm/executor/ipc.py             | 148 +----------------------
 tensorrt_llm/executor/postproc_worker.py |   6 +
 tensorrt_llm/executor/proxy.py           |   8 +-
 tensorrt_llm/executor/utils.py           |  35 +-----
 4 files changed, 15 insertions(+), 182 deletions(-)

diff --git a/tensorrt_llm/executor/ipc.py b/tensorrt_llm/executor/ipc.py
index f2ca3d8767..d05bade96a 100644
--- a/tensorrt_llm/executor/ipc.py
+++ b/tensorrt_llm/executor/ipc.py
@@ -1,11 +1,8 @@
-import io
 import time
 import traceback
-from multiprocessing.shared_memory import SharedMemory
 from queue import Queue
 from typing import Any, Optional
 
-import torch
 import zmq
 import zmq.asyncio
 
@@ -13,7 +10,6 @@
 from tensorrt_llm.logger import logger
 from ..llmapi.utils import (ManagedThread, enable_llm_debug, nvtx_mark,
                             nvtx_range, print_colored, print_colored_debug)
-from .utils import ExecutorResponse, ExecutorResponseTensors
 
 
 class ZeroMqQueue:
@@ -90,35 +86,11 @@ class ZeroMqQueue:
 
     def put(self, obj: Any):
         self.setup_lazily()
-
-        if isinstance(obj, ExecutorResponse):
-            tensors = self._store_tensors_in_shmm(obj.tensors)
-            obj = ExecutorResponse(
-                client_id=obj.client_id,
-                sequence_index=obj.sequence_index,
-                tensors=tensors,
-                finish_reasons=obj.finish_reasons,
-                is_final=obj.is_final,
-                error=obj.error,
-                timestamp=obj.timestamp,
-                disaggregated_params=obj.disaggregated_params)
-
         with nvtx_range("send", color="blue", category="IPC"):
             self.socket.send_pyobj(obj)
 
     async def put_async(self, obj: Any):
         self.setup_lazily()
-        if isinstance(obj, ExecutorResponse):
-            tensors = self._store_tensors_in_shmm(obj.tensors)
-            obj = ExecutorResponse(
-                client_id=obj.client_id,
-                tensors=tensors,
-                finish_reasons=obj.finish_reasons,
-                is_final=obj.is_final,
-                error=obj.error,
-                timestamp=obj.timestamp,
-                disaggregated_params=obj.disaggregated_params)
-
         try:
             await self.socket.send_pyobj(obj)
         except TypeError as e:
@@ -134,39 +106,12 @@ class ZeroMqQueue:
 
     def get(self) -> Any:
         self.setup_lazily()
-        obj = self.socket.recv_pyobj()
-        nvtx_mark("ipc.get", color="orange", category="IPC")
-
-        if isinstance(obj, ExecutorResponse):
-            tensors = self._load_tensors_from_shmm(obj.tensors)
-            obj = ExecutorResponse(
-                client_id=obj.client_id,
-                tensors=tensors,
-                finish_reasons=obj.finish_reasons,
-                is_final=obj.is_final,
-                error=obj.error,
-                timestamp=obj.timestamp,
-                disaggregated_params=obj.disaggregated_params)
-        return obj
+        return self.socket.recv_pyobj()
 
     async def get_async(self) -> Any:
         self.setup_lazily()
-        obj = await self.socket.recv_pyobj()
-        nvtx_mark("ipc.get", color="orange", category="IPC")
-
-        if isinstance(obj, ExecutorResponse):
-            tensors = self._load_tensors_from_shmm(obj.tensors)
-            obj = ExecutorResponse(
-                client_id=obj.client_id,
-                tensors=tensors,
-                sequence_index=obj.sequence_index,
-                finish_reasons=obj.finish_reasons,
-                is_final=obj.is_final,
-                error=obj.error,
-                timestamp=obj.timestamp,
-                disaggregated_params=obj.disaggregated_params)
-        return obj
+        return await self.socket.recv_pyobj()
 
     def close(self):
         if self.socket:
@@ -176,59 +121,6 @@ class ZeroMqQueue:
             self.context.term()
             self.context = None
 
-    def _store_tensors_in_shmm(
-            self, tensors: Optional["ExecutorResponseTensors"]
-    ) -> Optional["ExecutorResponseTensors"]:
-        if tensors is None:
-            return tensors
-
-        # The tensors are huge and cannot be transferred through socket directly. We need to store them in shared memory,
-        # and replace the tensors with the shared memory path.
-        def store_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]:
-            if tensor is None:
-                return None
-            # NOTE: We create random shmm here rather than two specific shmm for context and generation logit, since the
-            # shmm may not be read timely by the IpcQueue.get() in the other side, so there might be multiple alive shmm
-            # for logits.
-            # A known issue: the shmm instance may leak if the IpcQueue.get() thread is stopped before the IpcQueue.put()
-            # thread. This is not a big issue since the shmm will be automatically cleaned up when the process exits.
-            shm = SharedMemory(create=True, size=tensor.nbytes + 2048)
-            torch.save(tensor, shm._mmap)
-            shm.close()
-            return shm.name
-
-        return ExecutorResponseTensors(
-            output_token_ids=tensors.output_token_ids,
-            context_logits=store_tensor(tensors.context_logits),
-            generation_logits=store_tensor(tensors.generation_logits),
-            log_probs=tensors.log_probs,
-            cum_log_probs=tensors.cum_log_probs,
-        )
-
-    def _load_tensors_from_shmm(
-            self, tensors: Optional["ExecutorResponseTensors"]
-    ) -> Optional["ExecutorResponseTensors"]:
-        if tensors is None:
-            return tensors
-
-        def load_tensor(tensor: Optional[str]) -> Optional[torch.Tensor]:
-            if tensor is None or isinstance(tensor, torch.Tensor):
-                return tensor
-
-            shm = SharedMemory(name=tensor, create=False)
-            tensor = torch.load(io.BytesIO(shm.buf))
-            shm.close()
-            shm.unlink()
-            return tensor
-
-        return ExecutorResponseTensors(
-            output_token_ids=tensors.output_token_ids,
-            context_logits=load_tensor(tensors.context_logits),
-            generation_logits=load_tensor(tensors.generation_logits),
-            log_probs=tensors.log_probs,
-            cum_log_probs=tensors.cum_log_probs,
-        )
-
     def __del__(self):
         self.close()
 
@@ -284,45 +176,13 @@ class FusedIpcQueue:
     def put(self, obj: Any):
         self.setup_sender()
         if self.fuse_message:
-            self.sending_queue.put_nowait(self._prepare_message(obj))
+            self.sending_queue.put_nowait(obj)
         else:
             batch = obj if isinstance(obj, list) else [obj]
-            batch = [self._prepare_message(x) for x in batch]
             self.queue.put(batch)
 
     def get(self) -> Any:
-        obj = self.queue.get()
-        if isinstance(obj, list):
-            return [self._process_message(o) for o in obj]
-        return self._process_message(obj)
-
-    def _prepare_message(self, obj: Any) -> Any:
-        if isinstance(obj, ExecutorResponse):
-            tensors = self.queue._store_tensors_in_shmm(obj.tensors)
-            return ExecutorResponse(
-                client_id=obj.client_id,
-                tensors=tensors,
-                finish_reasons=obj.finish_reasons,
-                is_final=obj.is_final,
-                sequence_index=obj.sequence_index,
-                error=obj.error,
-                timestamp=obj.timestamp,
-                disaggregated_params=obj.disaggregated_params)
-        return obj
-
-    def _process_message(self, obj: Any) -> Any:
-        if isinstance(obj, ExecutorResponse):
-            tensors = self.queue._load_tensors_from_shmm(obj.tensors)
-            return ExecutorResponse(
-                client_id=obj.client_id,
-                tensors=tensors,
-                finish_reasons=obj.finish_reasons,
-                is_final=obj.is_final,
-                sequence_index=obj.sequence_index,
-                error=obj.error,
-                timestamp=obj.timestamp,
-                disaggregated_params=obj.disaggregated_params)
-        return obj
+        return self.queue.get()
 
     @property
     def address(self) -> str:
diff --git a/tensorrt_llm/executor/postproc_worker.py b/tensorrt_llm/executor/postproc_worker.py
index b3347a09c4..f07d1d3b43 100644
--- a/tensorrt_llm/executor/postproc_worker.py
+++ b/tensorrt_llm/executor/postproc_worker.py
@@ -119,6 +119,12 @@ class PostprocWorker:
 
     async def _handle_input(self, input: "PostprocWorker.Input") -> Any:
         ''' Handle a single response from await_response worker. '''
+        if input.rsp.result.context_logits is not None or \
+            input.rsp.result.generation_logits is not None:
+            raise ValueError(
+                "Context logits or generation logits are not supposed to be "
+                "sent to postprocessing workers.")
+
         with nvtx_range("handle_input", color="yellow", category="Postproc"):
             req_id = input.rsp.client_id
             if req_id not in self._records:
diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py
index 8b5bbd804d..37deae30bc 100644
--- a/tensorrt_llm/executor/proxy.py
+++ b/tensorrt_llm/executor/proxy.py
@@ -183,12 +183,12 @@ class ExecutorBindingsProxy(GenerationExecutor):
         try:
             data = queue.get()
         except:
-            logger.error(
+            logger.debug(
                 "proxy.py: Error in _iteration_result_task: queue.get()")
             return False
 
         if data is None:
-            logger.error("proxy.py: _iteration_result_task: data is None")
+            logger.debug("proxy.py: _iteration_result_task: data is None")
             return False  # shutdown the thread
 
         data = data if isinstance(data, list) else [data]
@@ -201,7 +201,7 @@ class ExecutorBindingsProxy(GenerationExecutor):
         try:
             for d in data:
                 if d is None:
-                    logger.error("proxy.py: _iteration_result_task: d is None")
+                    logger.debug("proxy.py: _iteration_result_task: d is None")
                     return False
 
                 if isinstance(queue, _SyncQueue):
@@ -219,7 +219,7 @@ class ExecutorBindingsProxy(GenerationExecutor):
             # and therefore event loop can already be closed.
             logger.debug("proxy.py: EventLoopShutdownError")
         except Exception as e:
-            logger.error(f"proxy.py: Error in _iteration_result_task: {e}")
+            logger.debug(f"proxy.py: Error in _iteration_result_task: {e}")
             raise e
 
         return True  # success
diff --git a/tensorrt_llm/executor/utils.py b/tensorrt_llm/executor/utils.py
index 5f8d42a522..3f85b12561 100644
--- a/tensorrt_llm/executor/utils.py
+++ b/tensorrt_llm/executor/utils.py
@@ -3,15 +3,11 @@ import concurrent.futures
 import os
 from concurrent.futures import ProcessPoolExecutor
 from queue import Empty, Queue
-from typing import Any, Callable, List, NamedTuple, Optional
-
-import torch
+from typing import Any, Callable, List, NamedTuple
 
 from tensorrt_llm.llmapi.utils import print_colored_debug
 from tensorrt_llm.logger import logger
 
-from ..bindings import executor as tllm
-from ..disaggregated_params import DisaggregatedParams
 from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession,
                                   RemoteMpiCommSessionClient)
 from ..llmapi.utils import print_colored_debug
@@ -90,41 +86,12 @@ class ProcessPoolExecutorSession(MpiSession):
         self.mpi_pool.shutdown(wait=True)
 
 
-class ExecutorResponseTensors(NamedTuple):
-    output_token_ids: List[List[int]]
-    # context_logits is a tensor or a string denoting the path to the shared memory.
-    context_logits: Optional[torch.Tensor | str]
-    # generation_logits is a tensor or a string denoting the path to the shared memory.
-    generation_logits: Optional[torch.Tensor | str]
-    log_probs: Optional[list]
-    cum_log_probs: Optional[list]
-
-
 class ErrorResponse(NamedTuple):
     client_id: int
     error_msg: str
     request_id: int
 
 
-class ExecutorResponse(NamedTuple):
-    """ The response from the cpp-executor to the Python main thread. """
-    client_id: int
-    tensors: Optional[ExecutorResponseTensors]
-    finish_reasons: Optional[List[tllm.FinishReason]]
-    is_final: Optional[bool]
-    sequence_index: Optional[int]
-    # There are two types of errors:
-    # 1. str for the errors from the cpp-executor.await_responses, this will be dispatched to the user's
-    #    generate_async as a per-request error, and won't stop the whole service.
-    # 2. Exception for the errors from the background threads/processes, this will be processed in the main thread,
-    #    and stop the whole service.
-    error: Optional[str | Exception]
-    # The timestamp of the creation of the response. We use this to track the IPC overhead.
-    timestamp: Optional[float] = None
-    # Optional disaggregated serving params needed by the generation instances
-    disaggregated_params: Optional[DisaggregatedParams] = None
-
-
 class IntraProcessQueue:
     ''' A Queue-like container for IPC within the same process. '''
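
After this cleanup, ZeroMqQueue.put()/get() lean entirely on pyzmq's pickle helpers, with no SharedMemory detour for response tensors; the new check in PostprocWorker._handle_input instead rejects responses that still carry logits. Below is a minimal, self-contained sketch of the send_pyobj/recv_pyobj round trip the simplified queue relies on. It is not TensorRT-LLM code; the socket setup, function name, and payload fields are hypothetical and chosen only to illustrate the pattern.

# Minimal sketch of a pickle round trip over ZeroMQ, assuming only pyzmq.
# send_pyobj() pickles the whole Python object; recv_pyobj() rebuilds it.
import zmq


def demo_pyobj_roundtrip() -> None:
    ctx = zmq.Context.instance()
    receiver = ctx.socket(zmq.PULL)
    port = receiver.bind_to_random_port("tcp://127.0.0.1")
    sender = ctx.socket(zmq.PUSH)
    sender.connect(f"tcp://127.0.0.1:{port}")

    # Any picklable payload travels through the socket directly; large tensors
    # such as logits are kept off this path by the PostprocWorker check above.
    sender.send_pyobj({"client_id": 1, "is_final": True})
    print(receiver.recv_pyobj())  # -> {'client_id': 1, 'is_final': True}

    sender.close()
    receiver.close()


if __name__ == "__main__":
    demo_pyobj_roundtrip()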