Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
chore: code cleanup for error logging and SharedMemory in proxy.py (#3432)
* cleanup log
* remove shared-memory
* remove ExecutorResponse
* add assert for postproc

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
This commit is contained in:
parent d7a0bf934c
commit c5e803ba48
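This change removes the ExecutorResponse special-casing from the IPC queues, so responses now travel through ZeroMQ's pickle-based send/receive path unchanged. A minimal sketch of that hand-off, assuming only pyzmq; the PAIR socket, address, and payload below are illustrative and not taken from the diff:

import threading

import zmq

ADDR = "tcp://127.0.0.1:5555"  # hypothetical address

def receiver():
    sock = zmq.Context.instance().socket(zmq.PAIR)
    sock.connect(ADDR)
    obj = sock.recv_pyobj()  # unpickles whatever the sender shipped
    print("received:", obj)
    sock.close()

sender = zmq.Context.instance().socket(zmq.PAIR)
sender.bind(ADDR)
t = threading.Thread(target=receiver)
t.start()
sender.send_pyobj({"client_id": 1, "is_final": True})  # pickled automatically
t.join()
sender.close()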
@@ -1,11 +1,8 @@
import io
import time
import traceback
from multiprocessing.shared_memory import SharedMemory
from queue import Queue
from typing import Any, Optional

import torch
import zmq
import zmq.asyncio

@@ -13,7 +10,6 @@ from tensorrt_llm.logger import logger

from ..llmapi.utils import (ManagedThread, enable_llm_debug, nvtx_mark,
                            nvtx_range, print_colored, print_colored_debug)
from .utils import ExecutorResponse, ExecutorResponseTensors


class ZeroMqQueue:
@@ -90,35 +86,11 @@

    def put(self, obj: Any):
        self.setup_lazily()

        if isinstance(obj, ExecutorResponse):
            tensors = self._store_tensors_in_shmm(obj.tensors)
            obj = ExecutorResponse(
                client_id=obj.client_id,
                sequence_index=obj.sequence_index,
                tensors=tensors,
                finish_reasons=obj.finish_reasons,
                is_final=obj.is_final,
                error=obj.error,
                timestamp=obj.timestamp,
                disaggregated_params=obj.disaggregated_params)

        with nvtx_range("send", color="blue", category="IPC"):
            self.socket.send_pyobj(obj)

    async def put_async(self, obj: Any):
        self.setup_lazily()
        if isinstance(obj, ExecutorResponse):
            tensors = self._store_tensors_in_shmm(obj.tensors)
            obj = ExecutorResponse(
                client_id=obj.client_id,
                tensors=tensors,
                finish_reasons=obj.finish_reasons,
                is_final=obj.is_final,
                error=obj.error,
                timestamp=obj.timestamp,
                disaggregated_params=obj.disaggregated_params)

        try:
            await self.socket.send_pyobj(obj)
        except TypeError as e:
@@ -134,39 +106,12 @@
    def get(self) -> Any:
        self.setup_lazily()

        obj = self.socket.recv_pyobj()
        nvtx_mark("ipc.get", color="orange", category="IPC")

        if isinstance(obj, ExecutorResponse):
            tensors = self._load_tensors_from_shmm(obj.tensors)
            obj = ExecutorResponse(
                client_id=obj.client_id,
                tensors=tensors,
                finish_reasons=obj.finish_reasons,
                is_final=obj.is_final,
                error=obj.error,
                timestamp=obj.timestamp,
                disaggregated_params=obj.disaggregated_params)
        return obj
        return self.socket.recv_pyobj()

    async def get_async(self) -> Any:
        self.setup_lazily()

        obj = await self.socket.recv_pyobj()
        nvtx_mark("ipc.get", color="orange", category="IPC")

        if isinstance(obj, ExecutorResponse):
            tensors = self._load_tensors_from_shmm(obj.tensors)
            obj = ExecutorResponse(
                client_id=obj.client_id,
                tensors=tensors,
                sequence_index=obj.sequence_index,
                finish_reasons=obj.finish_reasons,
                is_final=obj.is_final,
                error=obj.error,
                timestamp=obj.timestamp,
                disaggregated_params=obj.disaggregated_params)
        return obj
        return await self.socket.recv_pyobj()

    def close(self):
        if self.socket:
@@ -176,59 +121,6 @@
            self.context.term()
            self.context = None

    def _store_tensors_in_shmm(
            self, tensors: Optional["ExecutorResponseTensors"]
    ) -> Optional["ExecutorResponseTensors"]:
        if tensors is None:
            return tensors

        # The tensors are huge and cannot be transferred through socket directly. We need to store them in shared memory,
        # and replace the tensors with the shared memory path.
        def store_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]:
            if tensor is None:
                return None
            # NOTE: We create random shmm here rather than two specific shmm for context and generation logit, since the
            # shmm may not be read timely by the IpcQueue.get() in the other side, so there might be multiple alive shmm
            # for logits.
            # A known issue: the shmm instance may leak if the IpcQueue.get() thread is stopped before the IpcQueue.put()
            # thread. This is not a big issue since the shmm will be automatically cleaned up when the process exits.
            shm = SharedMemory(create=True, size=tensor.nbytes + 2048)
            torch.save(tensor, shm._mmap)
            shm.close()
            return shm.name

        return ExecutorResponseTensors(
            output_token_ids=tensors.output_token_ids,
            context_logits=store_tensor(tensors.context_logits),
            generation_logits=store_tensor(tensors.generation_logits),
            log_probs=tensors.log_probs,
            cum_log_probs=tensors.cum_log_probs,
        )

    def _load_tensors_from_shmm(
            self, tensors: Optional["ExecutorResponseTensors"]
    ) -> Optional["ExecutorResponseTensors"]:
        if tensors is None:
            return tensors

        def load_tensor(tensor: Optional[str]) -> Optional[torch.Tensor]:
            if tensor is None or isinstance(tensor, torch.Tensor):
                return tensor

            shm = SharedMemory(name=tensor, create=False)
            tensor = torch.load(io.BytesIO(shm.buf))
            shm.close()
            shm.unlink()
            return tensor

        return ExecutorResponseTensors(
            output_token_ids=tensors.output_token_ids,
            context_logits=load_tensor(tensors.context_logits),
            generation_logits=load_tensor(tensors.generation_logits),
            log_probs=tensors.log_probs,
            cum_log_probs=tensors.cum_log_probs,
        )

    def __del__(self):
        self.close()

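The deleted _store_tensors_in_shmm/_load_tensors_from_shmm helpers implemented the pattern the comment above describes: stage oversized logits in shared memory and send only the segment name over the socket. A rough standalone sketch of that round trip, using only public SharedMemory APIs (the removed code wrote through shm._mmap instead), with all names illustrative:

import io
from multiprocessing.shared_memory import SharedMemory

import torch

def stage(tensor: torch.Tensor) -> str:
    buf = io.BytesIO()
    torch.save(tensor, buf)
    data = buf.getvalue()
    shm = SharedMemory(create=True, size=len(data))
    shm.buf[:len(data)] = data  # copy the serialized tensor into the block
    shm.close()
    return shm.name  # only this name needs to cross the IPC boundary

def load(name: str) -> torch.Tensor:
    shm = SharedMemory(name=name, create=False)
    tensor = torch.load(io.BytesIO(bytes(shm.buf)))
    shm.close()
    shm.unlink()  # the reader is responsible for freeing the segment
    return tensor

logits = torch.randn(4, 8)
assert torch.equal(load(stage(logits)), logits)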
@@ -284,45 +176,13 @@ class FusedIpcQueue:
    def put(self, obj: Any):
        self.setup_sender()
        if self.fuse_message:
            self.sending_queue.put_nowait(self._prepare_message(obj))
            self.sending_queue.put_nowait(obj)
        else:
            batch = obj if isinstance(obj, list) else [obj]
            batch = [self._prepare_message(x) for x in batch]
            self.queue.put(batch)

    def get(self) -> Any:
        obj = self.queue.get()
        if isinstance(obj, list):
            return [self._process_message(o) for o in obj]
        return self._process_message(obj)

    def _prepare_message(self, obj: Any) -> Any:
        if isinstance(obj, ExecutorResponse):
            tensors = self.queue._store_tensors_in_shmm(obj.tensors)
            return ExecutorResponse(
                client_id=obj.client_id,
                tensors=tensors,
                finish_reasons=obj.finish_reasons,
                is_final=obj.is_final,
                sequence_index=obj.sequence_index,
                error=obj.error,
                timestamp=obj.timestamp,
                disaggregated_params=obj.disaggregated_params)
        return obj

    def _process_message(self, obj: Any) -> Any:
        if isinstance(obj, ExecutorResponse):
            tensors = self.queue._load_tensors_from_shmm(obj.tensors)
            return ExecutorResponse(
                client_id=obj.client_id,
                tensors=tensors,
                finish_reasons=obj.finish_reasons,
                is_final=obj.is_final,
                sequence_index=obj.sequence_index,
                error=obj.error,
                timestamp=obj.timestamp,
                disaggregated_params=obj.disaggregated_params)
        return obj
        return self.queue.get()

    @property
    def address(self) -> str:
@@ -119,6 +119,12 @@ class PostprocWorker:

    async def _handle_input(self, input: "PostprocWorker.Input") -> Any:
        ''' Handle a single response from await_response worker. '''
        if input.rsp.result.context_logits is not None or \
                input.rsp.result.generation_logits is not None:
            raise ValueError(
                "Context logits or generation logits are not supposed to be "
                "sent to postprocessing workers.")

        with nvtx_range("handle_input", color="yellow", category="Postproc"):
            req_id = input.rsp.client_id
            if req_id not in self._records:
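The new check in PostprocWorker._handle_input turns a silent assumption into an explicit failure: postprocessing workers should never receive context or generation logits. The same guard in isolation, with a hypothetical Result container standing in for the real response payload:

from typing import NamedTuple, Optional

import torch

class Result(NamedTuple):  # hypothetical stand-in for the response payload
    context_logits: Optional[torch.Tensor] = None
    generation_logits: Optional[torch.Tensor] = None

def check_no_logits(result: Result) -> None:
    # Mirrors the guard added above: logits must not reach postproc workers.
    if result.context_logits is not None or result.generation_logits is not None:
        raise ValueError(
            "Context logits or generation logits are not supposed to be "
            "sent to postprocessing workers.")

check_no_logits(Result())  # passes
try:
    check_no_logits(Result(context_logits=torch.zeros(2, 3)))
except ValueError as err:
    print("rejected:", err)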
@@ -183,12 +183,12 @@ class ExecutorBindingsProxy(GenerationExecutor):
        try:
            data = queue.get()
        except:
            logger.error(
            logger.debug(
                "proxy.py: Error in _iteration_result_task: queue.get()")
            return False

        if data is None:
            logger.error("proxy.py: _iteration_result_task: data is None")
            logger.debug("proxy.py: _iteration_result_task: data is None")
            return False  # shutdown the thread

        data = data if isinstance(data, list) else [data]
@@ -201,7 +201,7 @@ class ExecutorBindingsProxy(GenerationExecutor):
        try:
            for d in data:
                if d is None:
                    logger.error("proxy.py: _iteration_result_task: d is None")
                    logger.debug("proxy.py: _iteration_result_task: d is None")
                    return False

                if isinstance(queue, _SyncQueue):
@@ -219,7 +219,7 @@ class ExecutorBindingsProxy(GenerationExecutor):
                    # and therefore event loop can already be closed.
                    logger.debug("proxy.py: EventLoopShutdownError")
        except Exception as e:
            logger.error(f"proxy.py: Error in _iteration_result_task: {e}")
            logger.debug(f"proxy.py: Error in _iteration_result_task: {e}")
            raise e

        return True  # success

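The logging changes in proxy.py all follow one pattern: conditions that only occur while the proxy is shutting down (the queue is gone, the event loop is closed, a None sentinel arrives) are logged at debug level rather than error, so routine teardown no longer looks like a failure. A small sketch of that pattern with hypothetical names:

import logging
from queue import Empty, Queue

logger = logging.getLogger("proxy_sketch")  # hypothetical logger name

def drain_once(queue: Queue) -> bool:
    """Return False to stop the worker loop; teardown is reported quietly."""
    try:
        data = queue.get(timeout=0.1)
    except Empty:
        # Expected while shutting down: note it at debug level, not error.
        logger.debug("drain_once: queue.get() failed, assuming shutdown")
        return False

    if data is None:  # None is the conventional shutdown sentinel
        logger.debug("drain_once: got shutdown sentinel")
        return False

    logger.info("processing %s", data)
    return True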
@@ -3,15 +3,11 @@ import concurrent.futures
import os
from concurrent.futures import ProcessPoolExecutor
from queue import Empty, Queue
from typing import Any, Callable, List, NamedTuple, Optional

import torch
from typing import Any, Callable, List, NamedTuple

from tensorrt_llm.llmapi.utils import print_colored_debug
from tensorrt_llm.logger import logger

from ..bindings import executor as tllm
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession,
                                  RemoteMpiCommSessionClient)
from ..llmapi.utils import print_colored_debug
@@ -90,41 +86,12 @@ class ProcessPoolExecutorSession(MpiSession):
        self.mpi_pool.shutdown(wait=True)


class ExecutorResponseTensors(NamedTuple):
    output_token_ids: List[List[int]]
    # context_logits is a tensor or a string denoting the path to the shared memory.
    context_logits: Optional[torch.Tensor | str]
    # generation_logits is a tensor or a string denoting the path to the shared memory.
    generation_logits: Optional[torch.Tensor | str]
    log_probs: Optional[list]
    cum_log_probs: Optional[list]


class ErrorResponse(NamedTuple):
    client_id: int
    error_msg: str
    request_id: int


class ExecutorResponse(NamedTuple):
    """ The response from the cpp-executor to the Python main thread. """
    client_id: int
    tensors: Optional[ExecutorResponseTensors]
    finish_reasons: Optional[List[tllm.FinishReason]]
    is_final: Optional[bool]
    sequence_index: Optional[int]
    # There are two types of errors:
    # 1. str for the errors from the cpp-executor.await_responses, this will be dispatched to the user's
    #    generate_async as a per-request error, and won't stop the whole service.
    # 2. Exception for the errors from the background threads/processes, this will be processed in the main thread,
    #    and stop the whole service.
    error: Optional[str | Exception]
    # The timestamp of the creation of the response. We use this to track the IPC overhead.
    timestamp: Optional[float] = None
    # Optional disaggregated serving params needed by the generation instances
    disaggregated_params: Optional[DisaggregatedParams] = None


class IntraProcessQueue:
    ''' A Queue-like container for IPC within the same process. '''
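The comments on the removed ExecutorResponse spell out its two error channels: a str is a per-request failure that is surfaced to that request's generate_async caller, while an Exception is a fatal background failure that stops the whole service. A toy dispatcher showing how a consumer would distinguish the two; the Response type and dispatch function are illustrative, not part of the codebase:

from typing import NamedTuple, Optional, Union

class Response(NamedTuple):  # illustrative stand-in for the removed ExecutorResponse
    client_id: int
    error: Optional[Union[str, Exception]] = None

def dispatch(rsp: Response) -> None:
    if isinstance(rsp.error, Exception):
        # Background thread/process failure: fatal, re-raise and stop the service.
        raise rsp.error
    if isinstance(rsp.error, str):
        # Error from cpp-executor.await_responses: report to this request only.
        print(f"request {rsp.client_id} failed: {rsp.error}")
        return
    print(f"request {rsp.client_id} succeeded")

dispatch(Response(client_id=1))
dispatch(Response(client_id=2, error="sampling config rejected"))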