from dataclasses import dataclass, field from typing import Any, List, Literal, Optional, Tuple, Union from tensorrt_llm.serve.responses_utils import ResponsesStreamingProcessor from tensorrt_llm.serve.responses_utils import \ create_response_non_store as responses_api_create_response_non_store from .._utils import nvtx_range_debug from ..executor import (DetokenizedGenerationResultBase, GenerationResult, GenerationResultBase) from ..executor.postproc_worker import PostprocArgs from ..executor.result import Logprob, TokenLogprobs from ..llmapi import SamplingParams from ..llmapi.reasoning_parser import (BaseReasoningParser, ReasoningParserFactory) from ..llmapi.tokenizer import TransformersTokenizer # yapf: disable from .chat_utils import make_tool_call_id from .harmony_adapter import (handle_non_streaming_response, handle_streaming_response) from .openai_protocol import (ChatCompletionLogProbs, ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatCompletionToolsParam, ChatMessage, CompletionRequest, CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, DeltaFunctionCall, DeltaMessage, DeltaToolCall, FunctionCall, PromptTokensDetails, ResponsesRequest, ResponsesResponse, StreamOptions, ToolCall, UsageInfo, to_disaggregated_params) from .tool_parser.base_tool_parser import BaseToolParser from .tool_parser.core_types import ToolCallItem from .tool_parser.tool_parser_factory import ToolParserFactory # yapf: enable @dataclass(kw_only=True) class ChatPostprocArgs(PostprocArgs): echo: bool = False role: str model: str num_choices: int = 1 tools: Optional[List[ChatCompletionToolsParam]] = None tool_choice: Optional[Union[Literal["none"], ChatCompletionNamedToolChoiceParam]] = "none" return_logprobs: bool = False top_logprobs: bool = False stream_options: Optional[StreamOptions] = None last_message_content: Optional[str] = None reasoning_parser: Optional[str] = None tool_parser: Optional[str] = None reasoning_parser_dict: dict[int, BaseReasoningParser] = field( default_factory=dict) tool_parser_dict: dict[int, BaseToolParser] = field(default_factory=dict) has_tool_call: dict[int, bool] = field(default_factory=dict) tool_call_id_type: str = "random" chat_template_kwargs: Optional[dict[str, Any]] = None @classmethod def from_request(cls, request: ChatCompletionRequest): return cls( echo=request.echo, role="assistant" if request.add_generation_prompt else request.messages[-1]["role"], model=request.model, num_choices=request.n if request.n else 1, tools=request.tools, tool_choice=request.tool_choice, stream_options=request.stream_options, return_logprobs=bool(request.logprobs), top_logprobs=bool(request.top_logprobs), chat_template_kwargs=request.chat_template_kwargs, ) def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer, logprobs: List[float] | TokenLogprobs, top_logprobs: bool) -> ChatCompletionLogProbs: assert len(token_ids) == len(logprobs), \ "token_ids and logprobs have different lengths" content: List[ChatCompletionLogProbsContent] = [] for token_id, logprob in zip(token_ids, logprobs): logprob: float | dict[int, Logprob] token = tokenizer.decode(token_id) chat_logprob = ChatCompletionLogProbsContent( token=token, bytes=list(token.encode("utf-8", errors="replace")), ) if isinstance(logprob, dict): if token_id in logprob: chat_logprob.logprob = max(logprob[token_id].logprob, -9999.0) if top_logprobs: chat_logprob.top_logprobs = [ ChatCompletionLogProbsContent( token=(tk := tokenizer.decode(tid)), logprob=max(logprob.logprob, -9999.0), bytes=list(tk.encode("utf-8", errors="replace"))) for tid, logprob in logprob.items() ] else: chat_logprob.logprob = max(logprob, -9999.0) content.append(chat_logprob) chat_logprobs = ChatCompletionLogProbs(content=content) return chat_logprobs def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, streaming: bool) -> Tuple[str, str]: reasoning_parser = None if args.reasoning_parser is not None: if output_index not in args.reasoning_parser_dict: chat_template_kwargs = getattr(args, "chat_template_kwargs", None) args.reasoning_parser_dict[ output_index] = ReasoningParserFactory.create_reasoning_parser( args.reasoning_parser, chat_template_kwargs) reasoning_parser = args.reasoning_parser_dict[output_index] if reasoning_parser is not None: if not streaming: result = reasoning_parser.parse(text) else: result = reasoning_parser.parse_delta(text) content, reasoning_content = result.content, result.reasoning_content else: content, reasoning_content = text, "" return content, reasoning_content def apply_tool_parser(args: ChatPostprocArgs, output_index: int, text: str, streaming: bool) -> Tuple[str, List[ToolCallItem]]: tool_parser = None tools = args.tools if args.tool_parser is not None and tools is not None: if output_index not in args.tool_parser_dict: args.tool_parser_dict[ output_index] = ToolParserFactory.create_tool_parser( args.tool_parser) tool_parser = args.tool_parser_dict[output_index] if tool_parser is not None and tools is not None: if not streaming: result = tool_parser.detect_and_parse(text, tools) else: result = tool_parser.parse_streaming_increment(text, tools) normal_text, calls = result.normal_text, result.calls if result.calls: args.has_tool_call[output_index] = True else: normal_text, calls = text, [] return normal_text, calls @nvtx_range_debug("chat_stream_post_processor") def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> List[str]: def yield_first_chat(num_tokens: int, idx: int, role: str | None = None, content: str | None = None): choice_data = ChatCompletionResponseStreamChoice(index=idx, delta=DeltaMessage( role=role, content=content), finish_reason=None) chunk = ChatCompletionStreamResponse(choices=[choice_data], model=args.model) if include_continuous_usage: chunk.usage = UsageInfo( prompt_tokens=num_tokens, total_tokens=num_tokens, completion_tokens=0, prompt_tokens_details=PromptTokensDetails( cached_tokens=rsp.cached_tokens), ) data = chunk.model_dump_json(exclude_none=True) return data res: List[str] = [] finish_reason_sent = [False] * args.num_choices prompt_tokens = args.num_prompt_tokens if stream_option := args.stream_options: include_usage = stream_option.include_usage include_continuous_usage = include_usage and stream_option.continuous_usage_stats else: include_usage = False include_continuous_usage = False if args.first_iteration: for i in range(args.num_choices): res.append( f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n" ) if args.echo and args.last_message_content: res.append( f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n" ) args.first_iteration = False for output in rsp.outputs: i = output.index if finish_reason_sent[i]: continue delta_text = output.text_diff delta_text, reasoning_delta_text = apply_reasoning_parser( args, i, delta_text, True) if args.tool_choice and type( args.tool_choice) is ChatCompletionNamedToolChoiceParam: delta_message = DeltaMessage(tool_calls=[ DeltaToolCall( function=DeltaFunctionCall( name=args.tool_choice.function.name, arguments=delta_text), index=i, ), ], ) else: delta_text, calls = apply_tool_parser(args, i, delta_text, True) tool_calls = [] for call_item in calls: # Tool call ID should be generated only once per tool call if call_item.name: # First chunk: include ID and function name tool_call_id = make_tool_call_id( id_type=args.tool_call_id_type, func_name=call_item.name, idx=call_item.tool_index) function_name = call_item.name else: # Subsequent chunks: null ID and name for argument deltas tool_call_id = None function_name = None tool_calls.append( DeltaToolCall( id=tool_call_id, index=call_item.tool_index, function=DeltaFunctionCall( name=function_name, arguments=call_item.parameters, ), )) if tool_calls or delta_text or reasoning_delta_text or output.finish_reason: delta_message = DeltaMessage( content=delta_text, reasoning_content=reasoning_delta_text, tool_calls=tool_calls if tool_calls else None) else: continue choice = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None), stop_reason=output.stop_reason, ) if args.return_logprobs: logprobs = output.logprobs_diff token_ids = output.token_ids_diff choice.logprobs = create_logprobs(token_ids, args.tokenizer, logprobs, args.top_logprobs) if output.finish_reason is not None: if output.finish_reason == "stop" and args.has_tool_call.get( i, False): choice.finish_reason = "tool_calls" else: choice.finish_reason = output.finish_reason choice.stop_reason = output.stop_reason finish_reason_sent[i] = True chunk = ChatCompletionStreamResponse(choices=[choice], model=args.model) if include_continuous_usage: chunk.usage = UsageInfo(prompt_tokens=prompt_tokens, completion_tokens=output.length, total_tokens=output.length + prompt_tokens, prompt_tokens_details=PromptTokensDetails( cached_tokens=rsp.cached_tokens)) data = chunk.model_dump_json(exclude_none=True) res.append(f"data: {data}\n\n") if include_usage and rsp._done: completion_tokens = sum(output.length for output in rsp.outputs) final_usage = UsageInfo( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, prompt_tokens_details=PromptTokensDetails( cached_tokens=rsp.cached_tokens), ) final_usage_chunk = ChatCompletionStreamResponse(choices=[], model=args.model, usage=final_usage) final_usage_data = final_usage_chunk.model_dump_json() res.append(f"data: {final_usage_data}\n\n") return res @nvtx_range_debug("chat_response_post_processor") def chat_response_post_processor( rsp: GenerationResultBase, args: ChatPostprocArgs) -> ChatCompletionResponse: choices: List[ChatCompletionResponseChoice] = [] role = args.role for output in rsp.outputs: text, reasoning_text = apply_reasoning_parser(args, output.index, output.text, False) if args.tool_choice and isinstance(args.tool_choice, ChatCompletionNamedToolChoiceParam): message = ChatMessage( role=role, content="", tool_calls=[ ToolCall(function=FunctionCall( name=args.tool_choice.function.name, arguments=text)) ]) else: if text is None: text = "" text, calls = apply_tool_parser(args, output.index, text, False) tool_calls = [ ToolCall(function=FunctionCall(name=call.name or "", arguments=call.parameters)) for call in calls ] message = ChatMessage(role=role, content=text, reasoning_content=reasoning_text, tool_calls=tool_calls) disaggregated_params = to_disaggregated_params( output.disaggregated_params) choice = ChatCompletionResponseChoice( index=output.index, message=message, stop_reason=output.stop_reason, disaggregated_params=disaggregated_params, avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None), ) if output.finish_reason == "stop" and args.has_tool_call.get( output.index, False): choice.finish_reason = "tool_calls" else: choice.finish_reason = output.finish_reason if args.return_logprobs: choice.logprobs = create_logprobs(output.token_ids, args.tokenizer, output.logprobs, args.top_logprobs) choices.append(choice) if args.echo and args.last_message_content: for choice in choices: full_message = args.last_message_content + choice.message.content choice.message.content = full_message num_prompt_tokens = args.num_prompt_tokens num_generated_tokens = sum(len(output.token_ids) for output in rsp.outputs) usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens, prompt_tokens_details=PromptTokensDetails( cached_tokens=rsp.cached_tokens), ) response = ChatCompletionResponse( model=args.model, choices=choices, usage=usage, ) return response @dataclass(kw_only=True) class CompletionPostprocArgs(PostprocArgs): echo: bool = False model: str = None num_choices: int = 1 prompt_idx: int = 0 detokenize: bool = True prompt: Optional[str] = None stream_options: Optional[StreamOptions] = None @classmethod def from_request(cls, request: CompletionRequest): return cls( echo=request.echo, model=request.model, num_choices=request.n if request.n else 1, stream_options=request.stream_options, detokenize=request.detokenize, ) @nvtx_range_debug("completion_stream_post_processor") def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: CompletionPostprocArgs) -> List[str]: res: List[str] = [] prompt_tokens = args.num_prompt_tokens if stream_option := args.stream_options: include_usage = stream_option.include_usage include_continuous_usage = include_usage and stream_option.continuous_usage_stats else: include_usage = False include_continuous_usage = False for output in rsp.outputs: delta_text = output.text_diff if args.echo and args.first_iteration: delta_text = args.prompt + delta_text choice = CompletionResponseStreamChoice( index=args.prompt_idx * args.num_choices + output.index, text=delta_text if args.detokenize else "", token_ids=None if args.detokenize else output.token_ids_diff, finish_reason=output.finish_reason, stop_reason=output.stop_reason, avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None), ) chunk = CompletionStreamResponse(model=args.model, choices=[choice]) if include_continuous_usage: chunk.usage = UsageInfo(prompt_tokens=prompt_tokens, completion_tokens=output.length, total_tokens=output.length + prompt_tokens, prompt_tokens_details=PromptTokensDetails( cached_tokens=rsp.cached_tokens)) data = chunk.model_dump_json(exclude_unset=False) res.append(f"data: {data}\n\n") if include_usage and rsp._done: completion_tokens = sum(output.length for output in rsp.outputs) final_usage = UsageInfo( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, prompt_tokens_details=PromptTokensDetails( cached_tokens=rsp.cached_tokens), ) final_usage_chunk = ChatCompletionStreamResponse(choices=[], model=args.model, usage=final_usage) final_usage_data = final_usage_chunk.model_dump_json() res.append(f"data: {final_usage_data}\n\n") args.first_iteration = False return res @nvtx_range_debug("completion_response_post_processor") def completion_response_post_processor( rsp: GenerationResult, args: CompletionPostprocArgs) -> CompletionResponse: prompt_tokens = args.num_prompt_tokens completion_tokens = 0 choices = [] for output in rsp.outputs: text = output.text if args.echo: text = args.prompt + text disaggregated_params = to_disaggregated_params( output.disaggregated_params) choice = CompletionResponseChoice( text=text if args.detokenize else "", token_ids=None if args.detokenize else output.token_ids, index=args.prompt_idx * args.num_choices + output.index, disaggregated_params=disaggregated_params, context_logits=None if rsp.context_logits is None else rsp.context_logits.tolist(), stop_reason=output.stop_reason, finish_reason=output.finish_reason, avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None), ) completion_tokens += output.length choices.append(choice) usage = UsageInfo(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=completion_tokens + prompt_tokens, prompt_tokens_details=PromptTokensDetails( cached_tokens=rsp.cached_tokens)) response = CompletionResponse(choices=choices, model=args.model, usage=usage) return response @dataclass(kw_only=True) class ChatCompletionPostprocArgs(PostprocArgs): model: str tools: Optional[List[ChatCompletionToolsParam]] tool_choice: Optional[Union[Literal["none", "auto"], ChatCompletionNamedToolChoiceParam]] request_id: Optional[int] = None chat_template_kwargs: Optional[dict[str, Any]] = None @classmethod def from_request(cls, request: ChatCompletionRequest): return cls( model=request.model, tools=request.tools, tool_choice=request.tool_choice, chat_template_kwargs=request.chat_template_kwargs, ) @nvtx_range_debug("chat_harmony_post_processor") def chat_harmony_post_processor( rsp: GenerationResult, args: ChatCompletionPostprocArgs) -> ChatCompletionResponse: response = handle_non_streaming_response( tools=args.tools, tool_choice=args.tool_choice, outputs=rsp.outputs, model=args.model, num_prompt_tokens=args.num_prompt_tokens, ) return response @nvtx_range_debug("chat_harmony_streaming_post_processor") def chat_harmony_streaming_post_processor( rsp: GenerationResult, args: ChatCompletionPostprocArgs) -> List[str]: response = handle_streaming_response( tools=args.tools, tool_choice=args.tool_choice, result=rsp, model=args.model, request_id=args.request_id, done=rsp._done, num_prompt_tokens=args.num_prompt_tokens, ) return response @dataclass(kw_only=True) class ResponsesAPIPostprocArgs(PostprocArgs): model: str request: ResponsesRequest sampling_params: SamplingParams use_harmony: bool reasoning_parser: Optional[str] = None tool_parser: Optional[str] = None streaming_processor: Optional[ResponsesStreamingProcessor] = None @nvtx_range_debug("responses_api_post_processor") def responses_api_post_processor( rsp: GenerationResult, args: ResponsesAPIPostprocArgs) -> ResponsesResponse: return responses_api_create_response_non_store( generation_result=rsp, request=args.request, sampling_params=args.sampling_params, model_name=args.model, use_harmony=args.use_harmony, reasoning_parser=args.reasoning_parser, tool_parser=args.tool_parser, ) @nvtx_range_debug("responses_api_streaming_post_processor") def responses_api_streaming_post_processor( rsp: GenerationResult, args: ResponsesAPIPostprocArgs) -> List[str]: if args.streaming_processor is None: raise ValueError( "streaming_processor is required for streaming post-processing") outputs = args.streaming_processor.process_single_output(rsp) if rsp._done: outputs.append( args.streaming_processor.get_final_response_non_store(rsp)) return outputs