From 2aade46d18afbcaf9ee5be2fc1548f077ce03bbf Mon Sep 17 00:00:00 2001 From: Pengyun Lin <81065165+LinPoly@users.noreply.github.com> Date: Wed, 29 Oct 2025 15:48:29 +0800 Subject: [PATCH] [TRTLLM-8214][feat] Support Qwen3 tool parser (#8216) Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com> --- requirements.txt | 1 + tensorrt_llm/commands/serve.py | 16 +- tensorrt_llm/serve/chat_utils.py | 9 + tensorrt_llm/serve/openai_server.py | 3 + tensorrt_llm/serve/postprocess_handlers.py | 112 +++- tensorrt_llm/serve/tool_parser/__init__.py | 3 + .../serve/tool_parser/base_tool_parser.py | 324 ++++++++++ tensorrt_llm/serve/tool_parser/core_types.py | 35 + .../serve/tool_parser/qwen3_tool_parser.py | 114 ++++ .../serve/tool_parser/tool_parser_factory.py | 21 + tensorrt_llm/serve/tool_parser/utils.py | 56 ++ tests/integration/defs/test_e2e.py | 7 + .../integration/test_lists/test-db/l0_a10.yml | 3 + tests/unittest/llmapi/apps/README.md | 2 +- .../llmapi/apps/_test_openai_tool_call.py | 121 ++++ .../unittest/llmapi/apps/test_tool_parsers.py | 597 ++++++++++++++++++ 16 files changed, 1405 insertions(+), 19 deletions(-) create mode 100644 tensorrt_llm/serve/tool_parser/__init__.py create mode 100644 tensorrt_llm/serve/tool_parser/base_tool_parser.py create mode 100644 tensorrt_llm/serve/tool_parser/core_types.py create mode 100644 tensorrt_llm/serve/tool_parser/qwen3_tool_parser.py create mode 100644 tensorrt_llm/serve/tool_parser/tool_parser_factory.py create mode 100644 tensorrt_llm/serve/tool_parser/utils.py create mode 100644 tests/unittest/llmapi/apps/_test_openai_tool_call.py create mode 100644 tests/unittest/llmapi/apps/test_tool_parsers.py diff --git a/requirements.txt b/requirements.txt index 702004c8e1..7906df222a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -78,3 +78,4 @@ nvidia-cutlass-dsl==4.2.1; python_version >= "3.10" numba-cuda>=0.19.0 # WAR for nvbugs/5501820 plotly numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing +partial_json_parser diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index c28d199408..f4f188fdea 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -33,6 +33,7 @@ from tensorrt_llm.llmapi.mpi_session import find_free_port from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory from tensorrt_llm.logger import logger, severity_map from tensorrt_llm.serve import OpenAIDisaggServer, OpenAIServer +from tensorrt_llm.serve.tool_parser import ToolParserFactory # Global variable to store the Popen object of the child process _child_p_global: Optional[subprocess.Popen] = None @@ -150,6 +151,7 @@ def launch_server( host: str, port: int, llm_args: dict, + tool_parser: Optional[str] = None, metadata_server_cfg: Optional[MetadataServerConfig] = None, server_role: Optional[ServerRole] = None, disagg_cluster_config: Optional[DisaggClusterConfig] = None, @@ -173,6 +175,7 @@ def launch_server( server = OpenAIServer(llm=llm, model=model, + tool_parser=tool_parser, server_role=server_role, metadata_server_cfg=metadata_server_cfg, disagg_cluster_config=disagg_cluster_config, @@ -311,6 +314,12 @@ class ChoiceWithAlias(click.Choice): default=None, help="[Experimental] Specify the parser for reasoning models.", ) +@click.option( + "--tool_parser", + type=click.Choice(ToolParserFactory.parsers.keys()), + default=None, + help="[Experimental] Specify the parser for tool models.", +) @click.option("--metadata_server_config_file", type=str, default=None, @@ -352,7 +361,8 @@ def 
serve( gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float, num_postprocess_workers: int, trust_remote_code: bool, extra_llm_api_options: Optional[str], reasoning_parser: Optional[str], - metadata_server_config_file: Optional[str], server_role: Optional[str], + tool_parser: Optional[str], metadata_server_config_file: Optional[str], + server_role: Optional[str], fail_fast_on_attention_window_too_large: bool, otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool, disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str]): @@ -423,8 +433,8 @@ def serve( multimodal_server_config = MultimodalServerConfig( media_io_kwargs=parsed_media_io_kwargs) - launch_server(host, port, llm_args, metadata_server_cfg, server_role, - disagg_cluster_config, multimodal_server_config) + launch_server(host, port, llm_args, tool_parser, metadata_server_cfg, + server_role, disagg_cluster_config, multimodal_server_config) @click.command("mm_embedding_serve") diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py index 5f9f5d84ab..acda26b511 100644 --- a/tensorrt_llm/serve/chat_utils.py +++ b/tensorrt_llm/serve/chat_utils.py @@ -1,3 +1,4 @@ +import uuid from functools import partial from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal, Optional, Tuple, TypeAlias, TypedDict, Union, cast) @@ -220,3 +221,11 @@ def check_multiple_response(n: int, backend: Optional[str]): if n > 1 and backend == "pytorch": raise ValueError( "Multiple response is not supported in PyTorch workflow") + + +def make_tool_call_id(id_type: str = "random", func_name=None, idx=None): + if id_type == "kimi_k2": + return f"functions.{func_name}:{idx}" + else: + # by default return random + return f"chatcmpl-tool-{uuid.uuid4().hex}" diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index f6eb18d11c..6edc7c0744 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -78,12 +78,14 @@ class OpenAIServer: def __init__(self, llm: Union[LLM, MultimodalEncoder], model: str, + tool_parser: Optional[str], server_role: Optional[ServerRole], metadata_server_cfg: MetadataServerConfig, disagg_cluster_config: Optional[DisaggClusterConfig] = None, multimodal_server_config: Optional[MultimodalServerConfig] = None): self.llm = llm self.tokenizer = llm.tokenizer + self.tool_parser = tool_parser self.metadata_server = create_metadata_server(metadata_server_cfg) self.disagg_cluster_config = disagg_cluster_config self.multimodal_server_config = multimodal_server_config @@ -532,6 +534,7 @@ class OpenAIServer: prompt["multi_modal_data"] = mm_data postproc_args.reasoning_parser = self.llm.args.reasoning_parser + postproc_args.tool_parser = self.tool_parser if conversation and conversation[-1].get( "content") and conversation[-1].get("role") == get_role(): postproc_args.last_message_content = conversation[-1]["content"] diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py index ece007b539..f9d78a5354 100644 --- a/tensorrt_llm/serve/postprocess_handlers.py +++ b/tensorrt_llm/serve/postprocess_handlers.py @@ -10,6 +10,7 @@ from ..llmapi.reasoning_parser import (BaseReasoningParser, ReasoningParserFactory) from ..llmapi.tokenizer import TransformersTokenizer # yapf: disable +from .chat_utils import make_tool_call_id from .harmony_adapter import (handle_non_streaming_response, handle_streaming_response) from .openai_protocol import (ChatCompletionLogProbs, @@ -23,9 +24,13 
@@ from .openai_protocol import (ChatCompletionLogProbs, CompletionRequest, CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, - CompletionStreamResponse, DeltaMessage, - FunctionCall, PromptTokensDetails, StreamOptions, - ToolCall, UsageInfo, to_disaggregated_params) + CompletionStreamResponse, DeltaFunctionCall, + DeltaMessage, DeltaToolCall, FunctionCall, + PromptTokensDetails, StreamOptions, ToolCall, + UsageInfo, to_disaggregated_params) +from .tool_parser.base_tool_parser import BaseToolParser +from .tool_parser.core_types import ToolCallItem +from .tool_parser.tool_parser_factory import ToolParserFactory # yapf: enable @@ -33,8 +38,8 @@ from .openai_protocol import (ChatCompletionLogProbs, @dataclass(kw_only=True) class ChatPostprocArgs(PostprocArgs): echo: bool = False - role: str = None - model: str = None + role: str + model: str num_choices: int = 1 tools: Optional[List[ChatCompletionToolsParam]] = None tool_choice: Optional[Union[Literal["none"], @@ -44,8 +49,11 @@ class ChatPostprocArgs(PostprocArgs): stream_options: Optional[StreamOptions] = None last_message_content: Optional[str] = None reasoning_parser: Optional[str] = None + tool_parser: Optional[str] = None reasoning_parser_dict: dict[int, BaseReasoningParser] = field( default_factory=dict) + tool_parser_dict: dict[int, BaseToolParser] = field(default_factory=dict) + has_tool_call: dict[int, bool] = field(default_factory=dict) @classmethod def from_request(cls, request: ChatCompletionRequest): @@ -116,6 +124,31 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, return content, reasoning_content +def apply_tool_parser(args: ChatPostprocArgs, output_index: int, text: str, + streaming: bool) -> Tuple[str, List[ToolCallItem]]: + tool_parser = None + tools = args.tools + if args.tool_parser is not None and tools is not None: + if output_index not in args.tool_parser_dict: + args.tool_parser_dict[ + output_index] = ToolParserFactory.create_tool_parser( + args.tool_parser) + tool_parser = args.tool_parser_dict[output_index] + + if tool_parser is not None and tools is not None: + if not streaming: + result = tool_parser.detect_and_parse(text, tools) + else: + result = tool_parser.parse_streaming_increment(text, tools) + normal_text, calls = result.normal_text, result.calls + if result.calls: + args.has_tool_call[output_index] = True + else: + normal_text, calls = text, [] + + return normal_text, calls + + @nvtx_range_debug("chat_stream_post_processor") def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> List[str]: @@ -176,27 +209,63 @@ def chat_stream_post_processor(rsp: GenerationResultBase, if args.tool_choice and type( args.tool_choice) is ChatCompletionNamedToolChoiceParam: delta_message = DeltaMessage(tool_calls=[ - ToolCall(function=FunctionCall( - name=args.tool_choice.function.name, arguments=delta_text)) - ]) + DeltaToolCall( + function=DeltaFunctionCall( + name=args.tool_choice.function.name, + arguments=delta_text), + index=i, + ), + ], ) else: - delta_message = DeltaMessage(content=delta_text, - reasoning_content=reasoning_delta_text) + delta_text, calls = apply_tool_parser(args, i, delta_text, True) + tool_calls = [] + for call_item in calls: + # Tool call ID should be generated only once per tool call + if call_item.name: + # First chunk: include ID and function name + tool_call_id = make_tool_call_id() + function_name = call_item.name + else: + # Subsequent chunks: null ID and name for argument deltas + tool_call_id = None 
+ function_name = None + + tool_calls.append( + DeltaToolCall( + id=tool_call_id, + index=call_item.tool_index, + function=DeltaFunctionCall( + name=function_name, + arguments=call_item.parameters, + ), + )) + if tool_calls or delta_text or reasoning_delta_text or output.finish_reason: + delta_message = DeltaMessage( + content=delta_text, + reasoning_content=reasoning_delta_text, + tool_calls=tool_calls if tool_calls else None) + else: + continue choice = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, - finish_reason=None, avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', - None)) + None), + stop_reason=output.stop_reason, + ) if args.return_logprobs: logprobs = output.logprobs_diff token_ids = output.token_ids_diff choice.logprobs = create_logprobs(token_ids, args.tokenizer, logprobs, args.top_logprobs) if output.finish_reason is not None: - choice.finish_reason = output.finish_reason + if output.finish_reason == "stop" and args.has_tool_call.get( + i, False): + choice.finish_reason = "tool_calls" + else: + choice.finish_reason = output.finish_reason choice.stop_reason = output.stop_reason finish_reason_sent[i] = True chunk = ChatCompletionStreamResponse(choices=[choice], model=args.model) @@ -247,21 +316,34 @@ def chat_response_post_processor( name=args.tool_choice.function.name, arguments=text)) ]) else: + if text is None: + text = "" + text, calls = apply_tool_parser(args, output.index, text, False) + tool_calls = [ + ToolCall(function=FunctionCall(name=call.name or "", + arguments=call.parameters)) + for call in calls + ] message = ChatMessage(role=role, content=text, - reasoning_content=reasoning_text) + reasoning_content=reasoning_text, + tool_calls=tool_calls) disaggregated_params = to_disaggregated_params( output.disaggregated_params) choice = ChatCompletionResponseChoice( index=output.index, message=message, - finish_reason=output.finish_reason, stop_reason=output.stop_reason, disaggregated_params=disaggregated_params, avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None), ) + if output.finish_reason == "stop" and args.has_tool_call.get( + output.index, False): + choice.finish_reason = "tool_calls" + else: + choice.finish_reason = output.finish_reason if args.return_logprobs: choice.logprobs = create_logprobs(output.token_ids, args.tokenizer, diff --git a/tensorrt_llm/serve/tool_parser/__init__.py b/tensorrt_llm/serve/tool_parser/__init__.py new file mode 100644 index 0000000000..d862620c95 --- /dev/null +++ b/tensorrt_llm/serve/tool_parser/__init__.py @@ -0,0 +1,3 @@ +from .tool_parser_factory import ToolParserFactory + +__all__ = ["ToolParserFactory"] diff --git a/tensorrt_llm/serve/tool_parser/base_tool_parser.py b/tensorrt_llm/serve/tool_parser/base_tool_parser.py new file mode 100644 index 0000000000..9fa87ec0bf --- /dev/null +++ b/tensorrt_llm/serve/tool_parser/base_tool_parser.py @@ -0,0 +1,324 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/base_format_detector.py +import json +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +from partial_json_parser.core.exceptions import MalformedJSON +from partial_json_parser.core.options import Allow + +from tensorrt_llm.logger import logger + +from ..openai_protocol import ChatCompletionToolsParam as Tool +from .core_types import StreamingParseResult, ToolCallItem, _GetInfoFunc +from .utils import find_common_prefix, is_complete_json, partial_json_loads + + 
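+# Illustrative sketch of the streaming contract implemented below (comments
+# only, not part of the parser API; the bot_token and the chunk boundaries
+# are hypothetical). A subclass whose bot_token is "[TOOL_CALLS] " would
+# emit, per the unit tests added in this change:
+#
+#   parse_streaming_increment('[TOOL_CALLS] {"name":"get_weather"', tools)
+#     -> StreamingParseResult(calls=[ToolCallItem(
+#            tool_index=0, name="get_weather", parameters="")])
+#   parse_streaming_increment(',"arguments":{"location":"SF"}}', tools)
+#     -> StreamingParseResult(calls=[ToolCallItem(
+#            tool_index=0, parameters='{"location": "SF"}')])
+#
+# i.e. the tool name goes out once with empty parameters, then argument JSON
+# streams incrementally as it becomes parseable.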
+class BaseToolParser(ABC):
+    """Base class providing two sets of interfaces: one-time and streaming incremental."""
+
+    def __init__(self):
+        # Streaming state management
+        # Buffer for accumulating incomplete patterns that arrive across multiple streaming chunks
+        self._buffer = ""
+        # Stores complete tool call info (name and arguments) for each tool being parsed.
+        # Used by serving layer for completion handling when streaming ends.
+        # Format: [{"name": str, "arguments": dict}, ...]
+        self.prev_tool_call_arr: List[Dict] = []
+        # Index of currently streaming tool call. Starts at -1 (no active tool),
+        # increments as each tool completes. Tracks which tool's arguments are streaming.
+        self.current_tool_id: int = -1
+        # Flag for whether current tool's name has been sent to client.
+        # Tool names are sent first with empty parameters, then arguments stream incrementally.
+        self.current_tool_name_sent: bool = False
+        # Tracks raw JSON string content streamed to client for each tool's arguments.
+        # Critical for serving layer to calculate remaining content when streaming ends.
+        # Each index corresponds to a tool_id. Example: ['{"location": "San Francisco"', '{"temp": 72']
+        self.streamed_args_for_tool: List[str] = []
+
+        # Token configuration (override in subclasses)
+        self.bot_token = ""  # nosec B105
+        self.eot_token = ""  # nosec B105
+        self.tool_call_separator = ", "
+
+    def _get_tool_indices(self, tools: List[Tool]) -> Dict[str, int]:
+        """
+        Get a mapping of tool names to their indices in the tools list.
+
+        This utility method creates a dictionary mapping function names to their
+        indices in the tools list, which is commonly needed for tool validation
+        and ToolCallItem creation.
+
+        Args:
+            tools: List of available tools
+
+        Returns:
+            Dictionary mapping tool names to their indices
+        """
+        return {
+            tool.function.name: i
+            for i, tool in enumerate(tools) if tool.function.name
+        }
+
+    def parse_base_json(self, action: Any,
+                        tools: List[Tool]) -> List[ToolCallItem]:
+        tool_indices = self._get_tool_indices(tools)
+        if not isinstance(action, list):
+            action = [action]
+
+        results = []
+        for act in action:
+            name = act.get("name")
+            if name and name in tool_indices:
+                results.append(
+                    ToolCallItem(
+                        tool_index=
+                        -1,  # Caller should update this based on the actual tools array called
+                        name=name,
+                        parameters=json.dumps(
+                            act.get("parameters") or act.get("arguments", {}),
+                            ensure_ascii=False,
+                        ),
+                    ))
+            else:
+                logger.warning(
+                    f"Model attempted to call undefined function: {name}")
+
+        return results
+
+    @abstractmethod
+    def detect_and_parse(self, text: str,
+                         tools: List[Tool]) -> StreamingParseResult:
+        """
+        Parses the complete text in one go and returns the parsed tool calls.
+        The returned normal_text is the content that this parser will not
+        consume further.
+        """
+        action = json.loads(text)
+        return StreamingParseResult(calls=self.parse_base_json(action, tools))
+
+    def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
+        """
+        Check if buffer ends with a partial bot_token.
+        Return the length of the partial bot_token.
+
+        For some formats, the bot_token is not a single token in the model's
+        vocabulary, such as `[TOOL_CALLS] [` in Mistral.
+        """
+        for i in range(1, min(len(buffer) + 1, len(bot_token))):
+            if bot_token.startswith(buffer[-i:]):
+                return i
+        return 0
+
+    def parse_streaming_increment(self, new_text: str,
+                                  tools: List[Tool]) -> StreamingParseResult:
+        """
+        Streaming incremental parsing with tool validation.
+ + This base implementation works best with formats where: + 1. bot_token is followed immediately by JSON (e.g., bot_token + JSON_array) + 2. JSON can be parsed incrementally using partial_json_loads + 3. Multiple tool calls are separated by "; " or ", " + + Examples of incompatible formats (need custom implementation, may reuse some logic from this class): + - Each tool call is wrapped in a separate block: See Qwen25Detector + - Multiple separate blocks: [TOOL_CALLS] [...] \n [TOOL_CALLS] [...] + - Tool call is Pythonic style + + For incompatible formats, detectors should override this method with custom logic. + """ + # Append new text to buffer + self._buffer += new_text + current_text = self._buffer + + # The current_text has tool_call if it is the start of a new tool call sequence + # or it is the start of a new tool call after a tool call separator, when there is a previous tool call + if not (self.has_tool_call(current_text) or + (self.current_tool_id > 0 + and current_text.startswith(self.tool_call_separator))): + # Only clear buffer if we're sure no tool call is starting + if not self._ends_with_partial_token(self._buffer, self.bot_token): + normal_text = self._buffer + self._buffer = "" + if self.eot_token in normal_text: + normal_text = normal_text.replace(self.eot_token, "") + return StreamingParseResult(normal_text=normal_text) + else: + # Might be partial bot_token, keep buffering + return StreamingParseResult() + + # Build tool indices if not already built + if not hasattr(self, "_tool_indices"): + self._tool_indices = self._get_tool_indices(tools) + + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR + + try: + try: + tool_call_pos = current_text.find(self.bot_token) + if tool_call_pos != -1: + start_idx = tool_call_pos + len(self.bot_token) + elif self.current_tool_id > 0 and current_text.startswith( + self.tool_call_separator): + start_idx = len(self.tool_call_separator) + else: + start_idx = 0 + + if start_idx >= len(current_text): + return StreamingParseResult() + + (obj, end_idx) = partial_json_loads(current_text[start_idx:], + flags) + + is_current_complete = is_complete_json( + current_text[start_idx:start_idx + end_idx]) + + # Validate tool name if present + if "name" in obj and obj["name"] not in self._tool_indices: + # Invalid tool name - reset state + self._buffer = "" + self.current_tool_id = -1 + self.current_tool_name_sent = False + if self.streamed_args_for_tool: + self.streamed_args_for_tool.pop() + return StreamingParseResult() + + # Handle parameters/arguments consistency + # NOTE: we assume here that the obj is always partial of a single tool call + if "parameters" in obj: + assert ("arguments" not in obj + ), "model generated both parameters and arguments" + obj["arguments"] = obj["parameters"] + + current_tool_call = obj + + except MalformedJSON: + return StreamingParseResult() + + if not current_tool_call: + return StreamingParseResult() + + # Case 1: Handle tool name streaming + # This happens when we encounter a tool but haven't sent its name yet + if not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + + if function_name and function_name in self._tool_indices: + # If this is a new tool (current_tool_id was -1), initialize it + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.streamed_args_for_tool.append("") + # If this is a subsequent tool, ensure streamed_args_for_tool is large enough + elif self.current_tool_id >= len( + self.streamed_args_for_tool): + while 
len(self.streamed_args_for_tool + ) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + # Send the tool name with empty parameters + res = StreamingParseResult(calls=[ + ToolCallItem( + tool_index=self.current_tool_id, + name=function_name, + parameters="", + ) + ], ) + self.current_tool_name_sent = True + else: + res = StreamingParseResult() + + # Case 2: Handle streaming arguments + # This happens when we've already sent the tool name and now need to stream arguments incrementally + else: + cur_arguments = current_tool_call.get("arguments") + res = StreamingParseResult() + + if cur_arguments: + # Calculate how much of the arguments we've already streamed + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments) + prev_arguments = None + if self.current_tool_id < len(self.prev_tool_call_arr): + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + + argument_diff = None + + # If the current tool's JSON is complete, send all remaining arguments + if is_current_complete: + argument_diff = cur_args_json[sent:] + completing_tool_id = ( + self.current_tool_id + ) # Save the ID of the tool that's completing + + # Only remove the processed portion, keep unprocessed content + self._buffer = current_text[start_idx + end_idx:] + + if self.current_tool_id < len(self.prev_tool_call_arr): + self.prev_tool_call_arr[ + self.current_tool_id].clear() + self.current_tool_name_sent = False + self.streamed_args_for_tool[self.current_tool_id] = "" + self.current_tool_id += 1 + + # If the tool is still being parsed, send incremental changes + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments) + if cur_args_json != prev_args_json: + prefix = find_common_prefix(prev_args_json, + cur_args_json) + argument_diff = prefix[sent:] + + # Send the argument diff if there's something new + if argument_diff is not None: + # Use the correct tool_index: completing_tool_id for completed tools, current_tool_id for ongoing + tool_index_to_use = (completing_tool_id + if is_current_complete else + self.current_tool_id) + res = StreamingParseResult(calls=[ + ToolCallItem( + tool_index=tool_index_to_use, + parameters=argument_diff, + ) + ], ) + if not is_current_complete: + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + # Update prev_tool_call_arr with current state + if self.current_tool_id >= 0: + # Ensure prev_tool_call_arr is large enough + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call + + return res + + except Exception as e: + logger.error(f"Error in parse_streaming_increment: {e}") + return StreamingParseResult() + + @abstractmethod + def has_tool_call(self, text: str) -> bool: + """ + Check if the given text contains function call markers specific to this format. + """ + raise NotImplementedError() + + def supports_structural_tag(self) -> bool: + """Return True if this detector supports structural tag format.""" + return True + + @abstractmethod + def structure_info(self) -> _GetInfoFunc: + """ + Return a function that creates StructureInfo for constrained generation. + + The returned function takes a tool name and returns a StructureInfo object + containing the begin/end patterns and trigger tokens needed for constrained + generation of function calls in this format. 
+
+        Returns:
+            A function that takes a tool name (str) and returns StructureInfo
+        """
+        raise NotImplementedError()
diff --git a/tensorrt_llm/serve/tool_parser/core_types.py b/tensorrt_llm/serve/tool_parser/core_types.py
new file mode 100644
index 0000000000..88d14696a6
--- /dev/null
+++ b/tensorrt_llm/serve/tool_parser/core_types.py
@@ -0,0 +1,35 @@
+# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/qwen25_detector.py
+from dataclasses import dataclass
+from typing import Callable, List, Optional
+
+from pydantic import BaseModel
+
+
+class ToolCallItem(BaseModel):
+    """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
+
+    tool_index: int
+    name: Optional[str] = None
+    parameters: str  # JSON string
+
+
+class StreamingParseResult(BaseModel):
+    """Result of streaming incremental parsing."""
+
+    normal_text: str = ""
+    calls: List[ToolCallItem] = []
+
+
+@dataclass
+class StructureInfo:
+    begin: str
+    end: str
+    trigger: str
+
+
+"""
+Helper type alias: a function that takes a tool name string and returns a
+StructureInfo object, which can be used to construct a structural_tag object.
+"""
+_GetInfoFunc = Callable[[str], StructureInfo]
diff --git a/tensorrt_llm/serve/tool_parser/qwen3_tool_parser.py b/tensorrt_llm/serve/tool_parser/qwen3_tool_parser.py
new file mode 100644
index 0000000000..298389f47e
--- /dev/null
+++ b/tensorrt_llm/serve/tool_parser/qwen3_tool_parser.py
@@ -0,0 +1,114 @@
+# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/qwen25_detector.py
+import json
+import re
+from typing import List
+
+from tensorrt_llm.logger import logger
+
+from ..openai_protocol import ChatCompletionToolsParam as Tool
+from .base_tool_parser import BaseToolParser
+from .core_types import StreamingParseResult, StructureInfo, _GetInfoFunc
+
+
+class Qwen3ToolParser(BaseToolParser):
+    """
+    Detector for the Qwen 2.5 and Qwen 3 model function call format.
+
+    Format Structure:
+    ```
+    <tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
+    ```
+
+    Key Components:
+    - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
+    - Function Call Object: JSON object with "name" and "arguments" fields
+
+    Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default
+    """
+
+    def __init__(self):
+        """
+        Initializes the detector with necessary state variables.
+        """
+        super().__init__()
+        self.bot_token = "<tool_call>\n"  # nosec B105
+        self.eot_token = "\n</tool_call>"  # nosec B105
+        self.tool_call_separator = "\n"
+        self._normal_text_buffer = ""  # Buffer for handling partial end tokens
+
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a Qwen 3 format tool call."""
+        return self.bot_token in text
+
+    def detect_and_parse(self, text: str,
+                         tools: List[Tool]) -> StreamingParseResult:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: StreamingParseResult with the leading normal text and the parsed calls.
+ """ + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx != -1 else text + if self.bot_token not in text: + return StreamingParseResult(normal_text=normal_text, calls=[]) + + # Find all \n...\n blocks + pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}" + match_result_list = re.findall(pattern, text, re.DOTALL) + calls = [] + for match_result in match_result_list: + try: + parsed_call = json.loads(match_result.strip()) + calls.extend(self.parse_base_json(parsed_call, tools)) + except json.JSONDecodeError as e: + logger.warning( + f"Failed to parse JSON part: {match_result}, JSON parse error: {str(e)}" + ) + continue + return StreamingParseResult(normal_text=normal_text, calls=calls) + + def parse_streaming_increment(self, new_text: str, + tools: List[Tool]) -> StreamingParseResult: + """ + Streaming incremental parsing for Qwen 3 tool calls. + Uses base class implementation with buffering to handle partial end tokens. + """ + result = super().parse_streaming_increment(new_text, tools) + + # Handle partial end tokens that are streamed character by character + if result.normal_text: + self._normal_text_buffer += result.normal_text + + # Check if buffer contains complete end token (without leading newline) + end_token_without_newline = self.eot_token[1:] # "" + if end_token_without_newline in self._normal_text_buffer: + cleaned_text = self._normal_text_buffer.replace( + end_token_without_newline, "") + self._normal_text_buffer = "" + result.normal_text = cleaned_text + else: + # Check if buffer might contain partial end token at the end + partial_match_len = self._ends_with_partial_token( + self._normal_text_buffer, end_token_without_newline) + + if partial_match_len: + # Keep potential partial match in buffer, return the rest + result.normal_text = self._normal_text_buffer[: + -partial_match_len] + self._normal_text_buffer = self._normal_text_buffer[ + -partial_match_len:] + else: + # No partial match, return all buffered text + result.normal_text = self._normal_text_buffer + self._normal_text_buffer = "" + + return result + + def structure_info(self) -> _GetInfoFunc: + return lambda name: StructureInfo( + begin='\n{"name":"' + name + '", "arguments":', + end="}\n", + trigger="", + ) diff --git a/tensorrt_llm/serve/tool_parser/tool_parser_factory.py b/tensorrt_llm/serve/tool_parser/tool_parser_factory.py new file mode 100644 index 0000000000..73b02510a6 --- /dev/null +++ b/tensorrt_llm/serve/tool_parser/tool_parser_factory.py @@ -0,0 +1,21 @@ +from typing import Type + +from .base_tool_parser import BaseToolParser +from .qwen3_tool_parser import Qwen3ToolParser + + +class ToolParserFactory: + parsers: dict[str, Type[BaseToolParser]] = { + "qwen3": Qwen3ToolParser, + } + + @staticmethod + def create_tool_parser(tool_parser: str) -> BaseToolParser: + try: + tool_parser_class = ToolParserFactory.parsers[tool_parser.lower()] + return tool_parser_class() + except KeyError as e: + raise ValueError( + f"Invalid tool_parser: {tool_parser}\n" + f"Supported parsers: {list(ToolParserFactory.parsers.keys())}" + ) from e diff --git a/tensorrt_llm/serve/tool_parser/utils.py b/tensorrt_llm/serve/tool_parser/utils.py new file mode 100644 index 0000000000..7666036be5 --- /dev/null +++ b/tensorrt_llm/serve/tool_parser/utils.py @@ -0,0 +1,56 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/qwen25_detector.py +import json +from json import JSONDecodeError, JSONDecoder +from 
json.decoder import WHITESPACE +from typing import Any + +import partial_json_parser +from partial_json_parser.core.options import Allow + + +def find_common_prefix(s1: str, s2: str) -> str: + prefix = "" + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]: + """ + Parse incomplete or partial JSON strings commonly encountered during streaming. + + Args: + input_str (str): The potentially incomplete JSON string to parse. + flags (Allow): Bitwise flags controlling what types of partial data are allowed. + Common flags include: + - Allow.STR: Allow partial strings (e.g., '"hello wo' -> 'hello wo') + - Allow.OBJ: Allow partial objects (e.g., '{"key":' -> {'key': None}) + - Allow.ARR: Allow partial arrays (e.g., '[1, 2,' -> [1, 2]) + - Allow.ALL: Allow all types of partial data + + Returns: + Tuple[Any, int]: A tuple containing: + - parsed_object: The Python object parsed from the JSON + - consumed_length: Number of characters consumed from input_str + """ + try: + return (partial_json_parser.loads(input_str, flags), len(input_str)) + except (JSONDecodeError, IndexError) as e: + msg = getattr(e, "msg", str(e)) + if "Extra data" in msg or "pop from empty list" in msg: + start = WHITESPACE.match(input_str, 0).end() + obj, end = JSONDecoder().raw_decode(input_str, start) + return obj, end + raise + + +def is_complete_json(input_str: str) -> bool: + try: + json.loads(input_str) + return True + except JSONDecodeError: + return False diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index 3f1bab6c94..51b0b66c36 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -1613,6 +1613,13 @@ def test_openai_reasoning(llm_root, llm_venv, backend: str): ]) +def test_openai_tool_call(llm_root, llm_venv): + test_root = unittest_path() / "llmapi" / "apps" + llm_venv.run_cmd( + ["-m", "pytest", + str(test_root / "_test_openai_tool_call.py")]) + + @pytest.mark.parametrize("sampler", ["torch_sampler", "trtllm_sampler"]) def test_openai_completions_with_logit_bias(llm_root, llm_venv, sampler: str): test_root = unittest_path() / "llmapi" / "apps" diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 9aef279b8d..5fc56bd938 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -57,6 +57,7 @@ l0_a10: - test_e2e.py::test_trtllm_serve_top_logprobs[pytorch] - test_e2e.py::test_openai_misc_example[pytorch] - test_e2e.py::test_openai_reasoning[pytorch] + - test_e2e.py::test_openai_tool_call - test_e2e.py::test_openai_completions_example[pytorch] - test_e2e.py::test_openai_chat_example[pytorch] TIMEOUT (90) - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-] @@ -71,6 +72,8 @@ l0_a10: - unittest/llmapi/test_additional_model_outputs.py # executor - unittest/executor/test_rpc.py + # trtllm-serve CPU-only + - unittest/llmapi/apps/test_tool_parsers.py - condition: ranges: system_gpu_count: diff --git a/tests/unittest/llmapi/apps/README.md b/tests/unittest/llmapi/apps/README.md index ff316cfa1e..4d066a807a 100644 --- a/tests/unittest/llmapi/apps/README.md +++ b/tests/unittest/llmapi/apps/README.md @@ -1,3 +1,3 @@ -This directory contains the end-to-end tests for the LLM API applications in `examples/apps`. 
+This directory contains the end-to-end tests for `trtllm-serve`. These tests are triggered from `test_e2e.py`.
diff --git a/tests/unittest/llmapi/apps/_test_openai_tool_call.py b/tests/unittest/llmapi/apps/_test_openai_tool_call.py
new file mode 100644
index 0000000000..18dc644997
--- /dev/null
+++ b/tests/unittest/llmapi/apps/_test_openai_tool_call.py
@@ -0,0 +1,121 @@
+import json
+import os
+import sys
+
+import openai
+import pytest
+
+from .openai_server import RemoteOpenAIServer
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+from test_llm import get_model_path
+
+TOOLS = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_current_temperature",
+            "description": "Get current temperature at a location.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type":
+                        "string",
+                        "description":
+                        'The location to get the temperature for, in the format "City, State, Country".',
+                    },
+                    "unit": {
+                        "type":
+                        "string",
+                        "enum": ["celsius", "fahrenheit"],
+                        "description":
+                        'The unit to return the temperature in. Defaults to "celsius".',
+                    },
+                },
+                "required": ["location"],
+            },
+        },
+    },
+]
+
+
+def get_current_temperature(location: str, unit: str = "celsius") -> dict:
+    return {"temperature": 20 if unit == "celsius" else 68}
+
+
+@pytest.fixture(scope="module", ids=["Qwen3-0.6B"])
+def model_name() -> str:
+    return "Qwen3/Qwen3-0.6B"
+
+
+@pytest.fixture(scope="module")
+def server(model_name: str):
+    model_path = get_model_path(model_name)
+    args = ["--tool_parser", "qwen3"]
+    with RemoteOpenAIServer(model_path, cli_args=args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server: RemoteOpenAIServer):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio(loop_scope="module")
+async def test_tool_parser(client: openai.AsyncOpenAI, model_name: str):
+    response = await client.chat.completions.create(
+        model=model_name,
+        messages=[{
+            "role": "user",
+            "content": "What's the temperature in San Francisco now?"
+        }],
+        tools=TOOLS)
+    assert response.choices[0].finish_reason == "tool_calls"
+    message = response.choices[0].message
+    assert message.content is not None
+    assert message.tool_calls is not None
+    assert len(message.tool_calls) == 1
+    tool_call = message.tool_calls[0]
+    assert tool_call.function.name == "get_current_temperature"
+    args = json.loads(tool_call.function.arguments)
+    get_current_temperature(**args)
+
+
+@pytest.mark.asyncio(loop_scope="module")
+async def test_tool_parser_streaming(client: openai.AsyncOpenAI,
+                                     model_name: str):
+    response = await client.chat.completions.create(
+        model=model_name,
+        messages=[{
+            "role": "user",
+            "content": "What's the temperature in San Francisco now?"
+ }], + tools=TOOLS, + stream=True) + tool_id = None + tool_name = None + parameters = "" + finish_reason = None + + async for chunk in response: + if chunk.choices[0].delta.tool_calls: + tool_call = chunk.choices[0].delta.tool_calls[0] + if tool_call.id: + if tool_id is not None: + raise RuntimeError("tool_id already exists") + tool_id = tool_call.id + if tool_call.function.name: + if tool_name is not None: + raise RuntimeError("tool_name already exists") + tool_name = tool_call.function.name + if tool_call.function.arguments: + parameters += tool_call.function.arguments + if chunk.choices[0].finish_reason: + finish_reason = chunk.choices[0].finish_reason + assert tool_id is not None + assert tool_name == "get_current_temperature" + assert finish_reason == "tool_calls" + assert parameters + args = json.loads(parameters) + get_current_temperature(**args) diff --git a/tests/unittest/llmapi/apps/test_tool_parsers.py b/tests/unittest/llmapi/apps/test_tool_parsers.py new file mode 100644 index 0000000000..511f6a47fb --- /dev/null +++ b/tests/unittest/llmapi/apps/test_tool_parsers.py @@ -0,0 +1,597 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json + +import pytest + +from tensorrt_llm.serve.openai_protocol import (ChatCompletionToolsParam, + FunctionDefinition) +from tensorrt_llm.serve.tool_parser.base_tool_parser import BaseToolParser +from tensorrt_llm.serve.tool_parser.core_types import StructureInfo +from tensorrt_llm.serve.tool_parser.qwen3_tool_parser import Qwen3ToolParser + + +# Test fixtures for common tools +@pytest.fixture +def sample_tools(): + """Sample tools for testing.""" + return [ + ChatCompletionToolsParam( + type="function", + function=FunctionDefinition(name="get_weather", + description="Get the current weather", + parameters={ + "type": "object", + "properties": { + "location": { + "type": + "string", + "description": + "The city and state" + }, + "unit": { + "type": + "string", + "enum": + ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + })), + ChatCompletionToolsParam( + type="function", + function=FunctionDefinition(name="search_web", + description="Search the web", + parameters={ + "type": "object", + "properties": { + "query": { + "type": + "string", + "description": + "The search query" + } + }, + "required": ["query"] + })) + ] + + +# Concrete implementation of BaseToolParser for testing +class ConcreteToolParser(BaseToolParser): + """Concrete implementation of BaseToolParser for testing abstract methods.""" + + def __init__(self): + super().__init__() + self.bot_token = "[TOOL_CALLS] " + self.eot_token = "[/TOOL_CALLS]" + + def has_tool_call(self, text: str) -> bool: + return self.bot_token in text + + def detect_and_parse(self, text: str, tools): + # Placeholder to avoid NotImplementedError + pass + + def structure_info(self): + return lambda name: StructureInfo( + begin=f'[TOOL_CALLS] {{"name":"{name}", "arguments":', + end="}[/TOOL_CALLS]", + trigger="[TOOL_CALLS]") + + +# ============================================================================ +# BaseToolParser Tests +# ============================================================================ + + +class TestBaseToolParser: + """Test suite for BaseToolParser class.""" + + def test_initialization(self): + """Test that BaseToolParser initializes correctly.""" + parser = ConcreteToolParser() + assert parser._buffer == "" + assert parser.prev_tool_call_arr == [] + assert parser.current_tool_id == -1 + assert parser.current_tool_name_sent is False + assert parser.streamed_args_for_tool == [] + + def test_get_tool_indices(self, sample_tools): + """Test _get_tool_indices correctly maps tool names to indices.""" + parser = ConcreteToolParser() + indices = parser._get_tool_indices(sample_tools) + + assert len(indices) == 2 + assert indices["get_weather"] == 0 + assert indices["search_web"] == 1 + + def test_get_tool_indices_empty(self): + """Test _get_tool_indices with empty tools list.""" + parser = ConcreteToolParser() + indices = parser._get_tool_indices([]) + assert indices == {} + + def test_parse_base_json_single_tool(self, sample_tools): + """Test parse_base_json with a single tool call.""" + parser = ConcreteToolParser() + action = { + "name": "get_weather", + "parameters": { + "location": "San Francisco" + } + } + + results = parser.parse_base_json(action, sample_tools) + + assert len(results) == 1 + assert results[0].name == "get_weather" + assert json.loads(results[0].parameters) == { + "location": "San Francisco" + } + + def test_parse_base_json_with_arguments_key(self, sample_tools): + """Test parse_base_json handles 'arguments' key instead of 'parameters'.""" + parser = ConcreteToolParser() + action = 
{"name": "search_web", "arguments": {"query": "TensorRT"}} + + results = parser.parse_base_json(action, sample_tools) + + assert len(results) == 1 + assert results[0].name == "search_web" + assert json.loads(results[0].parameters) == {"query": "TensorRT"} + + def test_parse_base_json_multiple_tools(self, sample_tools): + """Test parse_base_json with multiple tool calls.""" + parser = ConcreteToolParser() + actions = [{ + "name": "get_weather", + "parameters": { + "location": "Boston" + } + }, { + "name": "search_web", + "arguments": { + "query": "Python" + } + }] + + results = parser.parse_base_json(actions, sample_tools) + + assert len(results) == 2 + assert results[0].name == "get_weather" + assert results[1].name == "search_web" + + def test_parse_base_json_undefined_function(self, sample_tools): + """Test parse_base_json handles undefined function names gracefully.""" + parser = ConcreteToolParser() + action = {"name": "undefined_function", "parameters": {}} + + results = parser.parse_base_json(action, sample_tools) + + # Should return empty list and log warning + assert len(results) == 0 + + def test_parse_base_json_missing_parameters(self, sample_tools): + """Test parse_base_json handles missing parameters.""" + parser = ConcreteToolParser() + action = {"name": "get_weather"} + + results = parser.parse_base_json(action, sample_tools) + + assert len(results) == 1 + assert json.loads(results[0].parameters) == {} + + def test_ends_with_partial_token(self): + """Test _ends_with_partial_token detection.""" + parser = ConcreteToolParser() + + # Partial token at end (bot_token starts with the suffix) + assert parser._ends_with_partial_token("Some text [TOOL", + "[TOOL_CALLS] ") == 5 + assert parser._ends_with_partial_token("Some text [", + "[TOOL_CALLS] ") == 1 + assert parser._ends_with_partial_token("Some text [TOOL_CALLS", + "[TOOL_CALLS] ") == 11 + + # No partial token + assert parser._ends_with_partial_token("Some text", + "[TOOL_CALLS] ") == 0 + assert parser._ends_with_partial_token("Some text [XYZ", + "[TOOL_CALLS] ") == 0 + + # Complete token at end (entire buffer is bot_token prefix but not complete match) + # When buffer equals bot_token, it returns 0 because it's not a partial anymore + assert parser._ends_with_partial_token("text [TOOL_CALLS] ", + "[TOOL_CALLS] ") == 0 + + def test_parse_streaming_increment_no_tool_call(self, sample_tools): + """Test streaming parser returns normal text when no tool call present.""" + parser = ConcreteToolParser() + + result = parser.parse_streaming_increment("Hello, world!", sample_tools) + + assert result.normal_text == "Hello, world!" 
+ assert len(result.calls) == 0 + + def test_parse_streaming_increment_partial_bot_token(self, sample_tools): + """Test streaming parser buffers partial bot token.""" + parser = ConcreteToolParser() + + # Send partial bot token + result = parser.parse_streaming_increment("[TOOL", sample_tools) + + # Should buffer and return nothing + assert result.normal_text == "" + assert len(result.calls) == 0 + assert parser._buffer == "[TOOL" + + def test_parse_streaming_increment_tool_name(self, sample_tools): + """Test streaming parser handles tool name streaming.""" + parser = ConcreteToolParser() + + # Send bot token with partial JSON containing name + result = parser.parse_streaming_increment( + '[TOOL_CALLS] {"name":"get_weather"', sample_tools) + + # Should send tool name with empty parameters + assert len(result.calls) == 1 + assert result.calls[0].name == "get_weather" + assert result.calls[0].parameters == "" + assert result.calls[0].tool_index == 0 + assert parser.current_tool_name_sent is True + + def test_parse_streaming_increment_tool_arguments(self, sample_tools): + """Test streaming parser handles incremental argument streaming.""" + parser = ConcreteToolParser() + + # First send tool name + result1 = parser.parse_streaming_increment( + '[TOOL_CALLS] {"name":"get_weather"', sample_tools) + # Should send tool name + assert len(result1.calls) == 1 + assert result1.calls[0].name == "get_weather" + + # Then send complete arguments (parser needs complete JSON to parse incrementally) + result2 = parser.parse_streaming_increment( + ',"arguments":{"location":"San Francisco"}}', sample_tools) + + # Should stream arguments or complete the tool call + # The base implementation uses partial JSON parsing, so it may return results + assert result2 is not None # Just verify it doesn't crash + + def test_parse_streaming_increment_complete_tool(self, sample_tools): + """Test streaming parser handles complete tool call.""" + parser = ConcreteToolParser() + + # Send complete tool call in one chunk + result = parser.parse_streaming_increment( + '[TOOL_CALLS] {"name":"get_weather","arguments":{"location":"Boston"}}', + sample_tools) + + # Should have sent tool name (first call) + assert len(result.calls) == 1 + assert result.calls[0].name == "get_weather" + + def test_parse_streaming_increment_invalid_tool_name(self, sample_tools): + """Test streaming parser handles invalid tool name.""" + parser = ConcreteToolParser() + + # Send invalid tool name + result = parser.parse_streaming_increment( + '[TOOL_CALLS] {"name":"invalid_tool"', sample_tools) + + # Should reset state + assert len(result.calls) == 0 + assert parser._buffer == "" + assert parser.current_tool_id == -1 + + def test_supports_structural_tag(self): + """Test supports_structural_tag returns True.""" + parser = ConcreteToolParser() + assert parser.supports_structural_tag() is True + + def test_structure_info(self): + """Test structure_info returns proper function.""" + parser = ConcreteToolParser() + func = parser.structure_info() + + info = func("test_function") + assert isinstance(info, StructureInfo) + assert "test_function" in info.begin + assert info.trigger == "[TOOL_CALLS]" + + +# ============================================================================ +# Qwen3ToolParser Tests +# ============================================================================ + + +class TestQwen3ToolParser: + """Test suite for Qwen3ToolParser class.""" + + def test_initialization(self): + """Test that Qwen3ToolParser initializes correctly.""" + parser = 
Qwen3ToolParser()
+
+        assert parser.bot_token == "<tool_call>\n"
+        assert parser.eot_token == "\n</tool_call>"
+        assert parser.tool_call_separator == "\n"
+        assert parser._normal_text_buffer == ""
+
+    def test_has_tool_call_true(self):
+        """Test has_tool_call returns True when tool call is present."""
+        parser = Qwen3ToolParser()
+        text = 'Some text <tool_call>\n{"name":"get_weather"}\n</tool_call>'
+
+        assert parser.has_tool_call(text) is True
+
+    def test_has_tool_call_false(self):
+        """Test has_tool_call returns False when no tool call present."""
+        parser = Qwen3ToolParser()
+        text = "Just some regular text without tool calls"
+
+        assert parser.has_tool_call(text) is False
+
+    def test_detect_and_parse_no_tool_call(self, sample_tools):
+        """Test detect_and_parse with text containing no tool calls."""
+        parser = Qwen3ToolParser()
+        text = "This is just a regular response."
+
+        result = parser.detect_and_parse(text, sample_tools)
+
+        assert result.normal_text == "This is just a regular response."
+        assert len(result.calls) == 0
+
+    def test_detect_and_parse_single_tool(self, sample_tools):
+        """Test detect_and_parse with a single tool call."""
+        parser = Qwen3ToolParser()
+        text = 'Normal text\n<tool_call>\n{"name":"get_weather","arguments":{"location":"NYC"}}\n</tool_call>'
+
+        result = parser.detect_and_parse(text, sample_tools)
+
+        assert result.normal_text == "Normal text"
+        assert len(result.calls) == 1
+        assert result.calls[0].name == "get_weather"
+        assert json.loads(result.calls[0].parameters) == {"location": "NYC"}
+
+    def test_detect_and_parse_multiple_tools(self, sample_tools):
+        """Test detect_and_parse with multiple tool calls."""
+        parser = Qwen3ToolParser()
+        text = (
+            '<tool_call>\n{"name":"get_weather","arguments":{"location":"LA"}}\n</tool_call>\n'
+            '<tool_call>\n{"name":"search_web","arguments":{"query":"AI"}}\n</tool_call>'
+        )
+
+        result = parser.detect_and_parse(text, sample_tools)
+
+        assert len(result.calls) == 2
+        assert result.calls[0].name == "get_weather"
+        assert result.calls[1].name == "search_web"
+
+    def test_detect_and_parse_malformed_json(self, sample_tools):
+        """Test detect_and_parse handles malformed JSON gracefully."""
+        parser = Qwen3ToolParser()
+        text = '<tool_call>\n{"name":"get_weather","arguments":MALFORMED}\n</tool_call>'
+
+        result = parser.detect_and_parse(text, sample_tools)
+
+        # Should return empty calls due to JSON parsing error
+        assert len(result.calls) == 0
+
+    def test_detect_and_parse_with_parameters_key(self, sample_tools):
+        """Test detect_and_parse handles 'parameters' key."""
+        parser = Qwen3ToolParser()
+        text = '<tool_call>\n{"name":"search_web","parameters":{"query":"test"}}\n</tool_call>'
+
+        result = parser.detect_and_parse(text, sample_tools)
+
+        assert len(result.calls) == 1
+        assert result.calls[0].name == "search_web"
+        assert json.loads(result.calls[0].parameters) == {"query": "test"}
+
+    def test_parse_streaming_increment_normal_text(self, sample_tools):
+        """Test streaming parser handles normal text without tool calls."""
+        parser = Qwen3ToolParser()
+
+        result = parser.parse_streaming_increment("Hello, how can I help?",
+                                                  sample_tools)
+
+        assert result.normal_text == "Hello, how can I help?"
+        assert len(result.calls) == 0
+
+    def test_parse_streaming_increment_partial_bot_token(self, sample_tools):
+        """Test streaming parser buffers partial bot token."""
+        parser = Qwen3ToolParser()
+
+        # Send partial bot token
+        result = parser.parse_streaming_increment("<tool_call>\n",
+                                                  sample_tools)
+
+        # Send partial JSON with name
+        result = parser.parse_streaming_increment('{"name":"get_weather"',
+                                                  sample_tools)
+
+        # Should send tool name
+        assert len(result.calls) == 1
+        assert result.calls[0].name == "get_weather"
+        assert result.calls[0].parameters == ""
+
+        # Send arguments
+        result = parser.parse_streaming_increment(
+            ',"arguments":{"location":"SF"}}\n</tool_call>', sample_tools)
+
+        # Should stream arguments
+        assert len(result.calls) == 1
+        assert json.loads(result.calls[0].parameters) == {"location": "SF"}
+
+    def test_parse_streaming_increment_end_token_handling(self, sample_tools):
+        """Test streaming parser handles end token correctly."""
+        parser = Qwen3ToolParser()
+
+        # Send complete tool call
+        parser.parse_streaming_increment(
+            '<tool_call>\n{"name":"get_weather","arguments":{"location":"NYC"}}\n</tool_call>',
+            sample_tools)
+
+        # The end token should be removed from normal text
+        # Check buffer state
+        assert parser._normal_text_buffer == ""
+
+    def test_parse_streaming_increment_multiple_tools_streaming(
+            self, sample_tools):
+        """Test streaming parser handles multiple tool calls."""
+        parser = Qwen3ToolParser()
+
+        # First tool
+        parser.parse_streaming_increment('<tool_call>\n', sample_tools)
+        parser.parse_streaming_increment(
+            '{"name":"get_weather","arguments":{"location":"NYC"}}\n</tool_call>\n',
+            sample_tools)
+
+        # Second tool
+        parser.parse_streaming_increment('<tool_call>\n', sample_tools)
+        result = parser.parse_streaming_increment('{"name":"search_web"',
+                                                  sample_tools)
+
+        # Should have started second tool
+        assert result.calls[0].name == "search_web"
+        assert result.calls[0].parameters == ""
+        assert result.calls[0].tool_index == 1
+
+    def test_structure_info_function(self):
+        """Test structure_info returns correct lambda function."""
+        parser = Qwen3ToolParser()
+        func = parser.structure_info()
+
+        info = func("test_function")
+
+        assert isinstance(info, StructureInfo)
+        assert info.begin == '<tool_call>\n{"name":"test_function", "arguments":'
+        assert info.end == "}\n</tool_call>"
+        assert info.trigger == "<tool_call>"
+
+    def test_structure_info_different_names(self):
+        """Test structure_info works with different function names."""
+        parser = Qwen3ToolParser()
+        func = parser.structure_info()
+
+        info1 = func("get_weather")
+        info2 = func("search_web")
+
+        assert "get_weather" in info1.begin
+        assert "search_web" in info2.begin
+        assert info1.end == info2.end == "}\n</tool_call>"
+
+    def test_qwen3_format_compliance(self, sample_tools):
+        """Test that Qwen3ToolParser follows the documented format structure."""
+        parser = Qwen3ToolParser()
+
+        # Test the exact format from the docstring
+        text = '<tool_call>\n{"name":"get_weather", "arguments":{"location":"Tokyo"}}\n</tool_call>'
+
+        result = parser.detect_and_parse(text, sample_tools)
+
+        assert len(result.calls) == 1
+        assert result.calls[0].name == "get_weather"
+        assert json.loads(result.calls[0].parameters) == {"location": "Tokyo"}
+
+    def test_undefined_tool_in_qwen3_format(self, sample_tools):
+        """Test Qwen3ToolParser handles undefined tool gracefully."""
+        parser = Qwen3ToolParser()
+        text = '<tool_call>\n{"name":"undefined_func","arguments":{}}\n</tool_call>'
+
+        result = parser.detect_and_parse(text, sample_tools)
+
+        # Should not return any calls for undefined function
+        assert len(result.calls) == 0
+
+
+# ============================================================================
+# Integration Tests
+# ============================================================================
+
+
+class TestToolParserIntegration:
+    """Integration tests for tool parsers."""
+
+    def test_end_to_end_single_tool(self, sample_tools):
+        """Test end-to-end parsing of a single tool call."""
+        parser = Qwen3ToolParser()
+
+        # Simulate streaming
+        chunks = [
+            "<tool_call>\n", '{"name":"get', '_weather"', ',"arguments":',
+            '{"location"', ':"Paris"}}\n', '</tool_call>'
+        ]
+
+        results = []
+        for chunk in chunks:
+            result = parser.parse_streaming_increment(chunk, sample_tools)
+            if result.calls or result.normal_text:
+                results.append(result)
+
+        # Should have received tool name and arguments
+        assert any(r.calls for r in results)
+
+    def test_mixed_content_and_tool_calls(self, sample_tools):
+        """Test parsing text that mixes normal content with tool calls."""
+        parser = Qwen3ToolParser()
+
+        text = (
+            'I will check the weather for you.\n'
+            '<tool_call>\n{"name":"get_weather","arguments":{"location":"London"}}\n</tool_call>\n'
+            'Let me search that for you.')
+
+        result = parser.detect_and_parse(text, sample_tools)
+
+        assert "I will check the weather for you." in result.normal_text
+        assert len(result.calls) == 1
+        assert result.calls[0].name == "get_weather"
+
+    def test_parser_state_reset(self, sample_tools):
+        """Test that parser state can be used for multiple requests."""
+        parser = Qwen3ToolParser()
+
+        # First request
+        result1 = parser.detect_and_parse(
+            '<tool_call>\n{"name":"get_weather","arguments":{"location":"NYC"}}\n</tool_call>',
+            sample_tools)
+
+        # Reset internal state for new request
+        parser2 = Qwen3ToolParser()
+
+        # Second request
+        result2 = parser2.detect_and_parse(
+            '<tool_call>\n{"name":"search_web","arguments":{"query":"test"}}\n</tool_call>',
+            sample_tools)
+
+        assert result1.calls[0].name == "get_weather"
+        assert result2.calls[0].name == "search_web"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
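+
+# Minimal usage sketch (comments only; the model path below is a placeholder,
+# not part of this change). The parser is wired up end to end via the new
+# trtllm-serve flag added in this patch:
+#
+#   trtllm-serve /path/to/Qwen3-0.6B --tool_parser qwen3
+#
+# and then exercised through the OpenAI-compatible chat endpoint with "tools"
+# in the request body, as in _test_openai_tool_call.py: the first streamed
+# tool-call delta carries the call id and function name, subsequent deltas
+# carry argument fragments, and the choice finishes with
+# finish_reason == "tool_calls".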