chore: code cleanup for error logging and SharedMemory in proxy.py (#3432)

* cleanup log

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

* remove shared-memory

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

* remove ExecutorResponse

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

* add assert for postproc

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

---------

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
Yan Chunwei 2025-04-10 21:57:06 +08:00 committed by GitHub
parent d7a0bf934c
commit c5e803ba48
4 changed files with 15 additions and 182 deletions


@@ -1,11 +1,8 @@
import io
import time
import traceback
from multiprocessing.shared_memory import SharedMemory
from queue import Queue
from typing import Any, Optional
import torch
import zmq
import zmq.asyncio
@@ -13,7 +10,6 @@ from tensorrt_llm.logger import logger
from ..llmapi.utils import (ManagedThread, enable_llm_debug, nvtx_mark,
nvtx_range, print_colored, print_colored_debug)
from .utils import ExecutorResponse, ExecutorResponseTensors
class ZeroMqQueue:
@@ -90,35 +86,11 @@ class ZeroMqQueue:
def put(self, obj: Any):
self.setup_lazily()
if isinstance(obj, ExecutorResponse):
tensors = self._store_tensors_in_shmm(obj.tensors)
obj = ExecutorResponse(
client_id=obj.client_id,
sequence_index=obj.sequence_index,
tensors=tensors,
finish_reasons=obj.finish_reasons,
is_final=obj.is_final,
error=obj.error,
timestamp=obj.timestamp,
disaggregated_params=obj.disaggregated_params)
with nvtx_range("send", color="blue", category="IPC"):
self.socket.send_pyobj(obj)
async def put_async(self, obj: Any):
self.setup_lazily()
if isinstance(obj, ExecutorResponse):
tensors = self._store_tensors_in_shmm(obj.tensors)
obj = ExecutorResponse(
client_id=obj.client_id,
tensors=tensors,
finish_reasons=obj.finish_reasons,
is_final=obj.is_final,
error=obj.error,
timestamp=obj.timestamp,
disaggregated_params=obj.disaggregated_params)
try:
await self.socket.send_pyobj(obj)
except TypeError as e:
@@ -134,39 +106,12 @@ class ZeroMqQueue:
def get(self) -> Any:
self.setup_lazily()
obj = self.socket.recv_pyobj()
nvtx_mark("ipc.get", color="orange", category="IPC")
if isinstance(obj, ExecutorResponse):
tensors = self._load_tensors_from_shmm(obj.tensors)
obj = ExecutorResponse(
client_id=obj.client_id,
tensors=tensors,
finish_reasons=obj.finish_reasons,
is_final=obj.is_final,
error=obj.error,
timestamp=obj.timestamp,
disaggregated_params=obj.disaggregated_params)
return obj
return self.socket.recv_pyobj()
async def get_async(self) -> Any:
self.setup_lazily()
obj = await self.socket.recv_pyobj()
nvtx_mark("ipc.get", color="orange", category="IPC")
if isinstance(obj, ExecutorResponse):
tensors = self._load_tensors_from_shmm(obj.tensors)
obj = ExecutorResponse(
client_id=obj.client_id,
tensors=tensors,
sequence_index=obj.sequence_index,
finish_reasons=obj.finish_reasons,
is_final=obj.is_final,
error=obj.error,
timestamp=obj.timestamp,
disaggregated_params=obj.disaggregated_params)
return obj
return await self.socket.recv_pyobj()
def close(self):
if self.socket:
@@ -176,59 +121,6 @@ class ZeroMqQueue:
self.context.term()
self.context = None
def _store_tensors_in_shmm(
self, tensors: Optional["ExecutorResponseTensors"]
) -> Optional["ExecutorResponseTensors"]:
if tensors is None:
return tensors
# The tensors are huge and cannot be transferred through the socket directly. We need to store them in shared memory
# and replace the tensors with the shared-memory segment name.
def store_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]:
if tensor is None:
return None
# NOTE: We create a randomly named shared-memory segment here rather than two fixed segments for the context and
# generation logits, since a segment may not be read promptly by IpcQueue.get() on the other side, so multiple
# live segments for logits may exist at once.
# A known issue: a segment may leak if the IpcQueue.get() thread is stopped before the IpcQueue.put() thread.
# This is not a big problem since the segments are cleaned up automatically when the process exits.
shm = SharedMemory(create=True, size=tensor.nbytes + 2048)
torch.save(tensor, shm._mmap)
shm.close()
return shm.name
return ExecutorResponseTensors(
output_token_ids=tensors.output_token_ids,
context_logits=store_tensor(tensors.context_logits),
generation_logits=store_tensor(tensors.generation_logits),
log_probs=tensors.log_probs,
cum_log_probs=tensors.cum_log_probs,
)
def _load_tensors_from_shmm(
self, tensors: Optional["ExecutorResponseTensors"]
) -> Optional["ExecutorResponseTensors"]:
if tensors is None:
return tensors
def load_tensor(tensor: Optional[str]) -> Optional[torch.Tensor]:
if tensor is None or isinstance(tensor, torch.Tensor):
return tensor
shm = SharedMemory(name=tensor, create=False)
tensor = torch.load(io.BytesIO(shm.buf))
shm.close()
shm.unlink()
return tensor
return ExecutorResponseTensors(
output_token_ids=tensors.output_token_ids,
context_logits=load_tensor(tensors.context_logits),
generation_logits=load_tensor(tensors.generation_logits),
log_probs=tensors.log_probs,
cum_log_probs=tensors.cum_log_probs,
)
def __del__(self):
self.close()
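
For reference, here is a minimal, self-contained sketch (separate from the diff above) of the shared-memory hand-off that the removed _store_tensors_in_shmm / _load_tensors_from_shmm helpers implemented: the sender serializes a tensor into a freshly created SharedMemory segment and sends only the segment name over the socket, while the receiver reopens the segment by name, deserializes the tensor, and unlinks it. The store_tensor / load_tensor functions and the 2048-byte margin below mirror the removed code but are illustrative only, not part of the library.

import io
from multiprocessing.shared_memory import SharedMemory

import torch


def store_tensor(tensor: torch.Tensor) -> str:
    # Extra bytes leave room for torch.save's serialization overhead.
    shm = SharedMemory(create=True, size=tensor.nbytes + 2048)
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    payload = buffer.getvalue()
    shm.buf[:len(payload)] = payload
    name = shm.name
    shm.close()
    return name  # only this short string needs to cross the IPC socket


def load_tensor(name: str) -> torch.Tensor:
    shm = SharedMemory(name=name, create=False)
    tensor = torch.load(io.BytesIO(shm.buf))
    shm.close()
    shm.unlink()  # the receiver frees the segment after loading
    return tensor


if __name__ == "__main__":
    logits = torch.randn(4, 32000)
    restored = load_tensor(store_tensor(logits))
    assert torch.equal(restored, logits)
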
@@ -284,45 +176,13 @@ class FusedIpcQueue:
def put(self, obj: Any):
self.setup_sender()
if self.fuse_message:
self.sending_queue.put_nowait(self._prepare_message(obj))
self.sending_queue.put_nowait(obj)
else:
batch = obj if isinstance(obj, list) else [obj]
batch = [self._prepare_message(x) for x in batch]
self.queue.put(batch)
def get(self) -> Any:
obj = self.queue.get()
if isinstance(obj, list):
return [self._process_message(o) for o in obj]
return self._process_message(obj)
def _prepare_message(self, obj: Any) -> Any:
if isinstance(obj, ExecutorResponse):
tensors = self.queue._store_tensors_in_shmm(obj.tensors)
return ExecutorResponse(
client_id=obj.client_id,
tensors=tensors,
finish_reasons=obj.finish_reasons,
is_final=obj.is_final,
sequence_index=obj.sequence_index,
error=obj.error,
timestamp=obj.timestamp,
disaggregated_params=obj.disaggregated_params)
return obj
def _process_message(self, obj: Any) -> Any:
if isinstance(obj, ExecutorResponse):
tensors = self.queue._load_tensors_from_shmm(obj.tensors)
return ExecutorResponse(
client_id=obj.client_id,
tensors=tensors,
finish_reasons=obj.finish_reasons,
is_final=obj.is_final,
sequence_index=obj.sequence_index,
error=obj.error,
timestamp=obj.timestamp,
disaggregated_params=obj.disaggregated_params)
return obj
return self.queue.get()
@property
def address(self) -> str:


@@ -119,6 +119,12 @@ class PostprocWorker:
async def _handle_input(self, input: "PostprocWorker.Input") -> Any:
''' Handle a single response from await_response worker. '''
if input.rsp.result.context_logits is not None or \
input.rsp.result.generation_logits is not None:
raise ValueError(
"Context logits or generation logits are not supposed to be "
"sent to postprocessing workers.")
with nvtx_range("handle_input", color="yellow", category="Postproc"):
req_id = input.rsp.client_id
if req_id not in self._records:


@@ -183,12 +183,12 @@ class ExecutorBindingsProxy(GenerationExecutor):
try:
data = queue.get()
except:
logger.error(
logger.debug(
"proxy.py: Error in _iteration_result_task: queue.get()")
return False
if data is None:
logger.error("proxy.py: _iteration_result_task: data is None")
logger.debug("proxy.py: _iteration_result_task: data is None")
return False # shutdown the thread
data = data if isinstance(data, list) else [data]
@@ -201,7 +201,7 @@ class ExecutorBindingsProxy(GenerationExecutor):
try:
for d in data:
if d is None:
logger.error("proxy.py: _iteration_result_task: d is None")
logger.debug("proxy.py: _iteration_result_task: d is None")
return False
if isinstance(queue, _SyncQueue):
@@ -219,7 +219,7 @@ class ExecutorBindingsProxy(GenerationExecutor):
# and therefore event loop can already be closed.
logger.debug("proxy.py: EventLoopShutdownError")
except Exception as e:
logger.error(f"proxy.py: Error in _iteration_result_task: {e}")
logger.debug(f"proxy.py: Error in _iteration_result_task: {e}")
raise e
return True # success


@@ -3,15 +3,11 @@ import concurrent.futures
import os
from concurrent.futures import ProcessPoolExecutor
from queue import Empty, Queue
from typing import Any, Callable, List, NamedTuple, Optional
import torch
from typing import Any, Callable, List, NamedTuple
from tensorrt_llm.llmapi.utils import print_colored_debug
from tensorrt_llm.logger import logger
from ..bindings import executor as tllm
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession,
RemoteMpiCommSessionClient)
from ..llmapi.utils import print_colored_debug
@@ -90,41 +86,12 @@ class ProcessPoolExecutorSession(MpiSession):
self.mpi_pool.shutdown(wait=True)
class ExecutorResponseTensors(NamedTuple):
output_token_ids: List[List[int]]
# context_logits is either a tensor or a string naming the shared-memory segment that holds it.
context_logits: Optional[torch.Tensor | str]
# generation_logits is either a tensor or a string naming the shared-memory segment that holds it.
generation_logits: Optional[torch.Tensor | str]
log_probs: Optional[list]
cum_log_probs: Optional[list]
class ErrorResponse(NamedTuple):
client_id: int
error_msg: str
request_id: int
class ExecutorResponse(NamedTuple):
""" The response from the cpp-executor to the Python main thread. """
client_id: int
tensors: Optional[ExecutorResponseTensors]
finish_reasons: Optional[List[tllm.FinishReason]]
is_final: Optional[bool]
sequence_index: Optional[int]
# There are two types of errors:
# 1. str: errors from cpp-executor.await_responses; these are dispatched to the user's generate_async
# call as per-request errors and do not stop the whole service.
# 2. Exception: errors from background threads/processes; these are processed in the main thread
# and stop the whole service.
error: Optional[str | Exception]
# The timestamp of the creation of the response. We use this to track the IPC overhead.
timestamp: Optional[float] = None
# Optional disaggregated serving params needed by the generation instances
disaggregated_params: Optional[DisaggregatedParams] = None
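
The comments on the removed ExecutorResponse.error field distinguish two kinds of errors. A hypothetical consumer could branch on the value's type as sketched below; dispatch_error, fail_request, and shutdown_service are illustrative placeholders, not functions from this codebase.

from typing import Optional, Union


def fail_request(client_id: int, message: str) -> None:
    # Placeholder: surface the error only to the awaiting generate_async() call for this client.
    print(f"request {client_id} failed: {message}")


def shutdown_service(exc: Exception) -> None:
    # Placeholder: a background-thread/process failure is fatal and stops the whole service.
    raise exc


def dispatch_error(client_id: int, error: Optional[Union[str, Exception]]) -> None:
    # Per the removed comments: a str is a per-request error from
    # cpp-executor.await_responses; an Exception comes from a background
    # thread or process and should stop the whole service.
    if error is None:
        return
    if isinstance(error, str):
        fail_request(client_id, error)
    else:
        shutdown_service(error)
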
class IntraProcessQueue:
''' A Queue-like container for IPC within the same process. '''