TensorRT-LLM/tensorrt_llm/executor/postproc_worker.py
Kaiyu Xie dce1dcc4f9
feat: Support post_proc for bench (#5122)
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
2025-06-15 13:02:38 +08:00

221 lines
8.0 KiB
Python

import asyncio
import traceback
from collections import deque
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
Optional)
import zmq
import zmq.asyncio
from .._utils import nvtx_range_debug
from ..bindings import executor as tllm
from ..llmapi.tokenizer import TransformersTokenizer, load_hf_tokenizer
from ..llmapi.utils import print_traceback_on_error
from ..sampling_params import SamplingParams
from .ipc import ZeroMqQueue
from .utils import is_llm_response
if TYPE_CHECKING:
from .result import (DetokenizedGenerationResultBase, GenerationResult,
GenerationResultBase)
# Public API of this module.
__all__ = ["PostprocWorker", "PostprocWorkerConfig"]
@dataclass(kw_only=True)
class PostprocArgs:
    """Mutable arguments passed to a post-processor on every call."""
    # True until the first response of a request has been postprocessed.
    first_iteration: bool = True
    # Number of prompt tokens of the request, when known.
    num_prompt_tokens: Optional[int] = None
    # Filled in by PostprocWorker with its shared tokenizer before each call.
    tokenizer: Optional[TransformersTokenizer] = None
@dataclass(kw_only=True)
class PostprocParams:
    """Per-request postprocessing callback and its argument bundle."""
    # Callback applied to the accumulated record on each response; its return
    # value becomes the payload shipped back over the push pipe.
    post_processor: Optional[Callable[["GenerationResultBase", PostprocArgs],
                                      Any]] = None
    # Arguments handed to post_processor; the worker injects its tokenizer here.
    postproc_args: Optional[PostprocArgs] = None
@dataclass
class PostprocWorkerConfig:
    """Configuration for the postprocess worker processes."""
    num_postprocess_workers: int = 0
    postprocess_tokenizer_dir: Optional[str] = None

    @property
    def enabled(self) -> bool:
        """True when postprocessing is delegated to worker processes."""
        return self.num_postprocess_workers > 0
class PostprocWorker:
    '''
    The worker to postprocess the responses from the executor's await_response.

    It pulls `Input` items from a ZeroMQ pull pipe, maintains one record per
    request (keyed by client_id), applies either the request's post-processor
    or default detokenization, and pushes batched `Output` items to a ZeroMQ
    push pipe. A `None` item on the pull pipe is the shutdown sentinel.
    '''

    @dataclass
    class Input:
        # One raw response from the executor's await_response stream.
        rsp: "tllm.Response"

        # The information necessary for creating a GenerationResult in the first Input for each request
        sampling_params: Optional[SamplingParams] = None
        postproc_params: Optional[PostprocParams] = None
        streaming: Optional[bool] = None

    class Output(NamedTuple):
        # The postprocessed result for one handled Input, sent over the push pipe.
        client_id: int
        res: Any
        is_final: bool
        error: str = ""

    def __init__(
        self,
        pull_pipe_addr: tuple[str, Optional[bytes]],
        push_pipe_addr: tuple[str, Optional[bytes]],
        tokenizer_dir: str,
        record_creator: Callable[
            ["PostprocWorker.Input", TransformersTokenizer], Any],
    ):
        '''
        Args:
            pull_pipe_addr (tuple[str, Optional[bytes]]): The address and HMAC key of the input IPC.
            push_pipe_addr (tuple[str, Optional[bytes]]): The address and HMAC key of the output IPC.
            tokenizer_dir (str): The directory to load tokenizer.
            record_creator (Callable[["PostprocWorker.Input", TransformersTokenizer], Any]):
                A factory that builds the per-request record from the first
                Input of a request and the shared tokenizer.
        '''
        # Per-request records keyed by client_id; an entry is dropped by the
        # mainloop once the final response for the request is handled.
        self._records: Dict[int, GenerationResult] = {}
        self._record_creator = record_creator
        # Client side of the input IPC; responses are pulled from here.
        self._pull_pipe = ZeroMqQueue(address=pull_pipe_addr,
                                      is_async=True,
                                      is_server=False,
                                      name="postprocess_pull_pipe")
        # Client side of the output IPC; postprocessed batches are pushed here.
        self._push_pipe = ZeroMqQueue(address=push_pipe_addr,
                                      is_async=True,
                                      is_server=False,
                                      socket_type=zmq.PUSH,
                                      name="postprocess_push_pipe")
        # Set when the None sentinel arrives; terminates the mainloop.
        self._to_stop = asyncio.Event()
        # NOTE(review): this deque is never used within this file — confirm it
        # is not referenced externally before removing.
        self._q = deque()

        # Load the tokenizer and share in all records
        self._tokenizer = load_hf_tokenizer(tokenizer_dir)

    @staticmethod
    def default_record_creator(
            inp: "PostprocWorker.Input", tokenizer: TransformersTokenizer
    ) -> "DetokenizedGenerationResultBase":
        '''Default record factory: a detokenizing result sharing `tokenizer`.

        Requires `inp.sampling_params`, which is only carried by the first
        Input of a request.
        '''
        from .result import DetokenizedGenerationResultBase
        assert inp.sampling_params is not None
        return DetokenizedGenerationResultBase(
            inp.rsp.client_id,
            sampling_params=inp.sampling_params,
            postproc_params=inp.postproc_params,
            streaming=inp.streaming,
            tokenizer=tokenizer)

    async def _handle_input(self, input: "PostprocWorker.Input") -> Any:
        ''' Handle a single response from await_response worker. '''
        # Logits are not expected on this path; fail loudly if they slipped
        # through rather than silently shuttling large tensors over IPC.
        if input.rsp.result.context_logits is not None or \
                input.rsp.result.generation_logits is not None:
            raise ValueError(
                "Context logits or generation logits are not supposed to be "
                "sent to postprocessing workers.")
        with nvtx_range_debug("handle_input",
                              color="yellow",
                              category="Postproc"):
            req_id = input.rsp.client_id
            if req_id not in self._records:
                # First response for this request: create its record from the
                # request-level metadata carried by this Input.
                # TODO: support variant creation later
                self._records[req_id] = self._record_creator(
                    input, self._tokenizer)

            record = self._records[req_id]
            record._handle_response(input.rsp)  # inplace

            # Let the result_handler determine the final output dtype.
            # NOTE: This will change the CompletionOutput._postprocess_result
            if postproc_params := record.postproc_params:
                result_handler, args = postproc_params.post_processor, postproc_params.postproc_args
                args.tokenizer = self._tokenizer
                out = result_handler(record, args)
            else:
                # This should only be called in streaming mode, and each time it
                # produces a single output.
                out = record.outputs[0]

            # TODO: Keep only the diff token_ids and text in streaming mode when
            # result_handler is not set
            return out

    async def _batched_put(self):
        ''' Batched IPC send. '''
        # Forward each batch produced by the mainloop; a None batch is the
        # shutdown sentinel for the consumer side.
        async for batch in self._mainloop():
            if batch is None:
                # notify the dispatch_result coroutine to quit
                await self._push_pipe.put_async(None)
                break
            assert isinstance(batch, list)
            await self._push_pipe.put_async(batch)

    async def _mainloop(self):
        ''' The loop for handle_response and keep producing outputs. '''

        async def handle_single_input(inp: PostprocWorker.Input,
                                      batch: List[PostprocWorker.Output]):
            # Postprocess one response and append its Output to `batch`.
            assert isinstance(
                inp, PostprocWorker.Input
            ), f"Expect PostprocWorker.Input, got {type(inp)}."
            client_id = inp.rsp.client_id
            # Non-LLM responses (e.g. error responses) are treated as final.
            is_final = inp.rsp.result.is_final if is_llm_response(
                inp.rsp) else True
            res = await self._handle_input(inp)
            batch.append(PostprocWorker.Output(client_id, res, is_final))
            if is_final:
                # Request complete: release its record.
                self._records.pop(client_id)

        while not self._to_stop.is_set():
            batch: List[PostprocWorker.Output] = []
            # The pull pipe may deliver a single Input or a list of them.
            inputs: Optional[List[PostprocWorker.Input]
                             | PostprocWorker.
                             Input] = await self._pull_pipe.get_async()
            if not isinstance(inputs, list):
                inputs = [inputs]
            for inp in inputs:
                if inp is None:
                    # Shutdown sentinel: stop the loop and tell the consumer.
                    self._to_stop.set()
                    yield None
                    break
                await handle_single_input(inp, batch)
            yield batch

    def start(self):
        ''' Start the workflow in the current thread. '''

        async def main():
            await asyncio.gather(self._batched_put())

        try:
            asyncio.run(main())
        except Exception as e:
            print(traceback.format_exc())
            # NOTE(review): a bare `raise` would preserve the original
            # traceback context here.
            raise e
@print_traceback_on_error
def postproc_worker_main(feedin_ipc_addr: tuple[str, Optional[bytes]],
                         feedout_ipc_addr: tuple[str, Optional[bytes]],
                         tokenizer_dir: str, record_creator: Callable):
    """Process entry point: build a PostprocWorker and run it until shutdown.

    Args:
        feedin_ipc_addr: Address and HMAC key of the input (pull) IPC.
        feedout_ipc_addr: Address and HMAC key of the output (push) IPC.
        tokenizer_dir: Directory from which the worker loads its tokenizer.
        record_creator: Factory building the per-request record.
    """
    PostprocWorker(feedin_ipc_addr,
                   feedout_ipc_addr,
                   tokenizer_dir=tokenizer_dir,
                   record_creator=record_creator).start()