# tensorrt_llm/serve/postprocess_handlers.py
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple, Union

from .._utils import nvtx_range_debug
from ..executor import (DetokenizedGenerationResultBase, GenerationResult,
GenerationResultBase)
from ..executor.postproc_worker import PostprocArgs
from ..executor.result import Logprob, TokenLogprobs
from ..llmapi.reasoning_parser import (BaseReasoningParser,
ReasoningParserFactory)
from ..llmapi.tokenizer import TransformersTokenizer
# yapf: disable
from .chat_utils import make_tool_call_id
from .harmony_adapter import (handle_non_streaming_response,
handle_streaming_response)
from .openai_protocol import (ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatCompletionToolsParam, ChatMessage,
CompletionRequest, CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse, DeltaFunctionCall,
DeltaMessage, DeltaToolCall, FunctionCall,
PromptTokensDetails, StreamOptions, ToolCall,
UsageInfo, to_disaggregated_params)
from .tool_parser.base_tool_parser import BaseToolParser
from .tool_parser.core_types import ToolCallItem
from .tool_parser.tool_parser_factory import ToolParserFactory
# yapf: enable


@dataclass(kw_only=True)
class ChatPostprocArgs(PostprocArgs):
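    """Postprocessing arguments derived from a ChatCompletionRequest.

    Reasoning/tool parsers and tool-call flags are keyed by output index so
    that each of the request's n choices is parsed independently.
    """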
echo: bool = False
role: str
model: str
num_choices: int = 1
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"],
ChatCompletionNamedToolChoiceParam]] = "none"
return_logprobs: bool = False
top_logprobs: bool = False
stream_options: Optional[StreamOptions] = None
last_message_content: Optional[str] = None
reasoning_parser: Optional[str] = None
tool_parser: Optional[str] = None
reasoning_parser_dict: dict[int, BaseReasoningParser] = field(
default_factory=dict)
tool_parser_dict: dict[int, BaseToolParser] = field(default_factory=dict)
has_tool_call: dict[int, bool] = field(default_factory=dict)

    @classmethod
def from_request(cls, request: ChatCompletionRequest):
return cls(
echo=request.echo,
role="assistant"
if request.add_generation_prompt else request.messages[-1]["role"],
model=request.model,
num_choices=request.n if request.n else 1,
tools=request.tools,
tool_choice=request.tool_choice,
stream_options=request.stream_options,
return_logprobs=bool(request.logprobs),
top_logprobs=bool(request.top_logprobs),
)


def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer,
logprobs: List[float] | TokenLogprobs,
top_logprobs: bool) -> ChatCompletionLogProbs:
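    """Convert token ids and their logprobs into ChatCompletionLogProbs.

    Accepts either plain floats or per-position {token_id: Logprob} dicts and
    optionally expands the top-logprob candidates for each position.
    """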
assert len(token_ids) == len(logprobs), \
"token_ids and logprobs have different lengths"
content: List[ChatCompletionLogProbsContent] = []
for token_id, logprob in zip(token_ids, logprobs):
logprob: float | dict[int, Logprob]
token = tokenizer.decode(token_id)
chat_logprob = ChatCompletionLogProbsContent(
token=token,
bytes=list(token.encode("utf-8", errors="replace")),
)
if isinstance(logprob, dict):
if token_id in logprob:
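                # Clamp to -9999.0 so that -inf serializes as valid JSON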
chat_logprob.logprob = max(logprob[token_id].logprob, -9999.0)
if top_logprobs:
chat_logprob.top_logprobs = [
ChatCompletionLogProbsContent(
token=(tk := tokenizer.decode(tid)),
logprob=max(logprob.logprob, -9999.0),
bytes=list(tk.encode("utf-8", errors="replace")))
for tid, logprob in logprob.items()
]
else:
chat_logprob.logprob = max(logprob, -9999.0)
content.append(chat_logprob)
chat_logprobs = ChatCompletionLogProbs(content=content)
return chat_logprobs


def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
streaming: bool) -> Tuple[str, str]:
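    """Split generated text into (content, reasoning_content).

    A reasoning parser is created lazily per output index and reused across
    streaming chunks; without a configured parser the text passes through
    unchanged as content.
    """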
reasoning_parser = None
if args.reasoning_parser is not None:
if output_index not in args.reasoning_parser_dict:
args.reasoning_parser_dict[
output_index] = ReasoningParserFactory.create_reasoning_parser(
args.reasoning_parser)
reasoning_parser = args.reasoning_parser_dict[output_index]
if reasoning_parser is not None:
if not streaming:
result = reasoning_parser.parse(text)
else:
result = reasoning_parser.parse_delta(text)
content, reasoning_content = result.content, result.reasoning_content
else:
content, reasoning_content = text, ""
return content, reasoning_content


def apply_tool_parser(args: ChatPostprocArgs, output_index: int, text: str,
streaming: bool) -> Tuple[str, List[ToolCallItem]]:
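    """Extract tool calls from generated text, returning (normal_text, calls).

    Parsing runs only when both a tool parser and request tools are present.
    Detected calls are recorded in has_tool_call so the finish reason can
    later be rewritten to "tool_calls".
    """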
tool_parser = None
tools = args.tools
if args.tool_parser is not None and tools is not None:
if output_index not in args.tool_parser_dict:
args.tool_parser_dict[
output_index] = ToolParserFactory.create_tool_parser(
args.tool_parser)
tool_parser = args.tool_parser_dict[output_index]
if tool_parser is not None and tools is not None:
if not streaming:
result = tool_parser.detect_and_parse(text, tools)
else:
result = tool_parser.parse_streaming_increment(text, tools)
normal_text, calls = result.normal_text, result.calls
if result.calls:
args.has_tool_call[output_index] = True
else:
normal_text, calls = text, []
return normal_text, calls


@nvtx_range_debug("chat_stream_post_processor")
def chat_stream_post_processor(rsp: GenerationResultBase,
args: ChatPostprocArgs) -> List[str]:
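    """Render one generation iteration as OpenAI SSE "data: ..." chunks.

    On the first iteration a role (and optional echo) chunk is emitted per
    choice; when stream_options asks for usage, a usage-only chunk is
    appended once the response is done.
    """
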
def yield_first_chat(num_tokens: int,
idx: int,
role: str | None = None,
content: str | None = None):
choice_data = ChatCompletionResponseStreamChoice(index=idx,
delta=DeltaMessage(
role=role,
content=content),
finish_reason=None)
chunk = ChatCompletionStreamResponse(choices=[choice_data],
model=args.model)
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_tokens,
total_tokens=num_tokens,
completion_tokens=0,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=rsp.cached_tokens),
)
data = chunk.model_dump_json(exclude_none=True)
return data

    res: List[str] = []
finish_reason_sent = [False] * args.num_choices
prompt_tokens = args.num_prompt_tokens
if stream_option := args.stream_options:
include_usage = stream_option.include_usage
include_continuous_usage = include_usage and stream_option.continuous_usage_stats
else:
include_usage = False
include_continuous_usage = False
if args.first_iteration:
for i in range(args.num_choices):
res.append(
f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n"
)
if args.echo and args.last_message_content:
res.append(
f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n"
)
args.first_iteration = False
for output in rsp.outputs:
i = output.index
if finish_reason_sent[i]:
continue
delta_text = output.text_diff
delta_text, reasoning_delta_text = apply_reasoning_parser(
args, i, delta_text, True)
        if args.tool_choice and isinstance(
                args.tool_choice, ChatCompletionNamedToolChoiceParam):
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(
function=DeltaFunctionCall(
name=args.tool_choice.function.name,
arguments=delta_text),
index=i,
),
], )
else:
delta_text, calls = apply_tool_parser(args, i, delta_text, True)
tool_calls = []
for call_item in calls:
# Tool call ID should be generated only once per tool call
if call_item.name:
# First chunk: include ID and function name
tool_call_id = make_tool_call_id()
function_name = call_item.name
else:
# Subsequent chunks: null ID and name for argument deltas
tool_call_id = None
function_name = None
tool_calls.append(
DeltaToolCall(
id=tool_call_id,
index=call_item.tool_index,
function=DeltaFunctionCall(
name=function_name,
arguments=call_item.parameters,
),
))
if tool_calls or delta_text or reasoning_delta_text or output.finish_reason:
delta_message = DeltaMessage(
content=delta_text,
reasoning_content=reasoning_delta_text,
tool_calls=tool_calls if tool_calls else None)
else:
continue
choice = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None),
stop_reason=output.stop_reason,
)
if args.return_logprobs:
logprobs = output.logprobs_diff
token_ids = output.token_ids_diff
choice.logprobs = create_logprobs(token_ids, args.tokenizer,
logprobs, args.top_logprobs)
if output.finish_reason is not None:
if output.finish_reason == "stop" and args.has_tool_call.get(
i, False):
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason
finish_reason_sent[i] = True
chunk = ChatCompletionStreamResponse(choices=[choice], model=args.model)
if include_continuous_usage:
chunk.usage = UsageInfo(prompt_tokens=prompt_tokens,
completion_tokens=output.length,
total_tokens=output.length + prompt_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=rsp.cached_tokens))
data = chunk.model_dump_json(exclude_none=True)
res.append(f"data: {data}\n\n")

    if include_usage and rsp._done:
completion_tokens = sum(output.length for output in rsp.outputs)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=rsp.cached_tokens),
)
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
model=args.model,
usage=final_usage)
final_usage_data = final_usage_chunk.model_dump_json()
res.append(f"data: {final_usage_data}\n\n")
return res


@nvtx_range_debug("chat_response_post_processor")
def chat_response_post_processor(
rsp: GenerationResultBase,
args: ChatPostprocArgs) -> ChatCompletionResponse:
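    """Build a complete ChatCompletionResponse from a finished generation.

    Applies the reasoning and tool parsers, honors a named tool_choice, echo
    and logprobs, and fills in usage accounting.
    """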
choices: List[ChatCompletionResponseChoice] = []
role = args.role
for output in rsp.outputs:
text, reasoning_text = apply_reasoning_parser(args, output.index,
output.text, False)
if args.tool_choice and isinstance(args.tool_choice,
ChatCompletionNamedToolChoiceParam):
message = ChatMessage(
role=role,
content="",
tool_calls=[
ToolCall(function=FunctionCall(
name=args.tool_choice.function.name, arguments=text))
])
else:
if text is None:
text = ""
text, calls = apply_tool_parser(args, output.index, text, False)
tool_calls = [
ToolCall(function=FunctionCall(name=call.name or "",
arguments=call.parameters))
for call in calls
]
message = ChatMessage(role=role,
content=text,
reasoning_content=reasoning_text,
tool_calls=tool_calls)
disaggregated_params = to_disaggregated_params(
output.disaggregated_params)
choice = ChatCompletionResponseChoice(
index=output.index,
message=message,
stop_reason=output.stop_reason,
disaggregated_params=disaggregated_params,
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None),
)
if output.finish_reason == "stop" and args.has_tool_call.get(
output.index, False):
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = output.finish_reason
if args.return_logprobs:
choice.logprobs = create_logprobs(output.token_ids, args.tokenizer,
output.logprobs,
args.top_logprobs)
choices.append(choice)
if args.echo and args.last_message_content:
for choice in choices:
full_message = args.last_message_content + choice.message.content
choice.message.content = full_message
num_prompt_tokens = args.num_prompt_tokens
num_generated_tokens = sum(len(output.token_ids) for output in rsp.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=rsp.cached_tokens),
)
response = ChatCompletionResponse(
model=args.model,
choices=choices,
usage=usage,
)
return response


@dataclass(kw_only=True)
class CompletionPostprocArgs(PostprocArgs):
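    """Postprocessing arguments derived from a CompletionRequest."""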
echo: bool = False
    model: Optional[str] = None
num_choices: int = 1
prompt_idx: int = 0
detokenize: bool = True
prompt: Optional[str] = None
stream_options: Optional[StreamOptions] = None

    @classmethod
def from_request(cls, request: CompletionRequest):
return cls(
echo=request.echo,
model=request.model,
num_choices=request.n if request.n else 1,
stream_options=request.stream_options,
detokenize=request.detokenize,
)


@nvtx_range_debug("completion_stream_post_processor")
def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase,
args: CompletionPostprocArgs) -> List[str]:
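    """Render one generation iteration as SSE chunks for text completions.

    Echoes the prompt into the first chunk when requested and appends a
    usage-only chunk once the response is done, mirroring the chat variant.
    """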
res: List[str] = []
prompt_tokens = args.num_prompt_tokens
if stream_option := args.stream_options:
include_usage = stream_option.include_usage
include_continuous_usage = include_usage and stream_option.continuous_usage_stats
else:
include_usage = False
include_continuous_usage = False
for output in rsp.outputs:
delta_text = output.text_diff
if args.echo and args.first_iteration:
delta_text = args.prompt + delta_text
choice = CompletionResponseStreamChoice(
index=args.prompt_idx * args.num_choices + output.index,
text=delta_text if args.detokenize else "",
token_ids=None if args.detokenize else output.token_ids_diff,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None),
)
chunk = CompletionStreamResponse(model=args.model, choices=[choice])
if include_continuous_usage:
chunk.usage = UsageInfo(prompt_tokens=prompt_tokens,
completion_tokens=output.length,
total_tokens=output.length + prompt_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=rsp.cached_tokens))
data = chunk.model_dump_json(exclude_unset=False)
res.append(f"data: {data}\n\n")

    if include_usage and rsp._done:
completion_tokens = sum(output.length for output in rsp.outputs)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=rsp.cached_tokens),
)
        final_usage_chunk = CompletionStreamResponse(choices=[],
                                                     model=args.model,
                                                     usage=final_usage)
final_usage_data = final_usage_chunk.model_dump_json()
res.append(f"data: {final_usage_data}\n\n")
args.first_iteration = False
return res


@nvtx_range_debug("completion_response_post_processor")
def completion_response_post_processor(
rsp: GenerationResult,
args: CompletionPostprocArgs) -> CompletionResponse:
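    """Build a complete CompletionResponse from a finished generation."""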
prompt_tokens = args.num_prompt_tokens
completion_tokens = 0
choices = []
for output in rsp.outputs:
text = output.text
if args.echo:
text = args.prompt + text
disaggregated_params = to_disaggregated_params(
output.disaggregated_params)
choice = CompletionResponseChoice(
text=text if args.detokenize else "",
token_ids=None if args.detokenize else output.token_ids,
index=args.prompt_idx * args.num_choices + output.index,
disaggregated_params=disaggregated_params,
context_logits=None
if rsp.context_logits is None else rsp.context_logits.tolist(),
stop_reason=output.stop_reason,
finish_reason=output.finish_reason,
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None),
)
completion_tokens += output.length
choices.append(choice)
usage = UsageInfo(prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=completion_tokens + prompt_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=rsp.cached_tokens))
response = CompletionResponse(choices=choices,
model=args.model,
usage=usage)
return response


@dataclass(kw_only=True)
class ChatCompletionPostprocArgs(PostprocArgs):
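    """Postprocessing arguments for the Harmony chat-completions path."""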
model: str
tools: Optional[List[ChatCompletionToolsParam]]
tool_choice: Optional[Union[Literal["none", "auto"],
ChatCompletionNamedToolChoiceParam]]
request_id: Optional[int] = None

    @classmethod
def from_request(cls, request: ChatCompletionRequest):
return cls(
model=request.model,
tools=request.tools,
tool_choice=request.tool_choice,
)


@nvtx_range_debug("chat_harmony_post_processor")
def chat_harmony_post_processor(
rsp: GenerationResult,
args: ChatCompletionPostprocArgs) -> ChatCompletionResponse:
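    """Build a non-streaming chat response via the Harmony adapter."""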
response = handle_non_streaming_response(
tools=args.tools,
tool_choice=args.tool_choice,
outputs=rsp.outputs,
model=args.model,
num_prompt_tokens=args.num_prompt_tokens,
)
return response


@nvtx_range_debug("chat_harmony_streaming_post_processor")
def chat_harmony_streaming_post_processor(
rsp: GenerationResult, args: ChatCompletionPostprocArgs) -> List[str]:
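    """Render streaming chat chunks via the Harmony adapter."""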
response = handle_streaming_response(
tools=args.tools,
tool_choice=args.tool_choice,
result=rsp,
model=args.model,
request_id=args.request_id,
done=rsp._done,
num_prompt_tokens=args.num_prompt_tokens,
)
return response