From 2aade46d18afbcaf9ee5be2fc1548f077ce03bbf Mon Sep 17 00:00:00 2001
From: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
Date: Wed, 29 Oct 2025 15:48:29 +0800
Subject: [PATCH] [TRTLLM-8214][feat] Support Qwen3 tool parser (#8216)
Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
---
requirements.txt | 1 +
tensorrt_llm/commands/serve.py | 16 +-
tensorrt_llm/serve/chat_utils.py | 9 +
tensorrt_llm/serve/openai_server.py | 3 +
tensorrt_llm/serve/postprocess_handlers.py | 112 +++-
tensorrt_llm/serve/tool_parser/__init__.py | 3 +
.../serve/tool_parser/base_tool_parser.py | 324 ++++++++++
tensorrt_llm/serve/tool_parser/core_types.py | 35 +
.../serve/tool_parser/qwen3_tool_parser.py | 114 ++++
.../serve/tool_parser/tool_parser_factory.py | 21 +
tensorrt_llm/serve/tool_parser/utils.py | 56 ++
tests/integration/defs/test_e2e.py | 7 +
.../integration/test_lists/test-db/l0_a10.yml | 3 +
tests/unittest/llmapi/apps/README.md | 2 +-
.../llmapi/apps/_test_openai_tool_call.py | 121 ++++
.../unittest/llmapi/apps/test_tool_parsers.py | 597 ++++++++++++++++++
16 files changed, 1405 insertions(+), 19 deletions(-)
create mode 100644 tensorrt_llm/serve/tool_parser/__init__.py
create mode 100644 tensorrt_llm/serve/tool_parser/base_tool_parser.py
create mode 100644 tensorrt_llm/serve/tool_parser/core_types.py
create mode 100644 tensorrt_llm/serve/tool_parser/qwen3_tool_parser.py
create mode 100644 tensorrt_llm/serve/tool_parser/tool_parser_factory.py
create mode 100644 tensorrt_llm/serve/tool_parser/utils.py
create mode 100644 tests/unittest/llmapi/apps/_test_openai_tool_call.py
create mode 100644 tests/unittest/llmapi/apps/test_tool_parsers.py
diff --git a/requirements.txt b/requirements.txt
index 702004c8e1..7906df222a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -78,3 +78,4 @@ nvidia-cutlass-dsl==4.2.1; python_version >= "3.10"
numba-cuda>=0.19.0 # WAR for nvbugs/5501820
plotly
numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
+partial_json_parser
diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py
index c28d199408..f4f188fdea 100644
--- a/tensorrt_llm/commands/serve.py
+++ b/tensorrt_llm/commands/serve.py
@@ -33,6 +33,7 @@ from tensorrt_llm.llmapi.mpi_session import find_free_port
from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
from tensorrt_llm.logger import logger, severity_map
from tensorrt_llm.serve import OpenAIDisaggServer, OpenAIServer
+from tensorrt_llm.serve.tool_parser import ToolParserFactory
# Global variable to store the Popen object of the child process
_child_p_global: Optional[subprocess.Popen] = None
@@ -150,6 +151,7 @@ def launch_server(
host: str,
port: int,
llm_args: dict,
+ tool_parser: Optional[str] = None,
metadata_server_cfg: Optional[MetadataServerConfig] = None,
server_role: Optional[ServerRole] = None,
disagg_cluster_config: Optional[DisaggClusterConfig] = None,
@@ -173,6 +175,7 @@ def launch_server(
server = OpenAIServer(llm=llm,
model=model,
+ tool_parser=tool_parser,
server_role=server_role,
metadata_server_cfg=metadata_server_cfg,
disagg_cluster_config=disagg_cluster_config,
@@ -311,6 +314,12 @@ class ChoiceWithAlias(click.Choice):
default=None,
help="[Experimental] Specify the parser for reasoning models.",
)
+@click.option(
+ "--tool_parser",
+ type=click.Choice(ToolParserFactory.parsers.keys()),
+ default=None,
+ help="[Experimental] Specify the parser for tool models.",
+)
@click.option("--metadata_server_config_file",
type=str,
default=None,
@@ -352,7 +361,8 @@ def serve(
gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float,
num_postprocess_workers: int, trust_remote_code: bool,
extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
- metadata_server_config_file: Optional[str], server_role: Optional[str],
+ tool_parser: Optional[str], metadata_server_config_file: Optional[str],
+ server_role: Optional[str],
fail_fast_on_attention_window_too_large: bool,
otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str]):
@@ -423,8 +433,8 @@ def serve(
multimodal_server_config = MultimodalServerConfig(
media_io_kwargs=parsed_media_io_kwargs)
- launch_server(host, port, llm_args, metadata_server_cfg, server_role,
- disagg_cluster_config, multimodal_server_config)
+ launch_server(host, port, llm_args, tool_parser, metadata_server_cfg,
+ server_role, disagg_cluster_config, multimodal_server_config)
@click.command("mm_embedding_serve")
diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py
index 5f9f5d84ab..acda26b511 100644
--- a/tensorrt_llm/serve/chat_utils.py
+++ b/tensorrt_llm/serve/chat_utils.py
@@ -1,3 +1,4 @@
+import uuid
from functools import partial
from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
Optional, Tuple, TypeAlias, TypedDict, Union, cast)
@@ -220,3 +221,11 @@ def check_multiple_response(n: int, backend: Optional[str]):
if n > 1 and backend == "pytorch":
raise ValueError(
"Multiple response is not supported in PyTorch workflow")
+
+
+def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
+ if id_type == "kimi_k2":
+ return f"functions.{func_name}:{idx}"
+ else:
+ # by default return random
+ return f"chatcmpl-tool-{uuid.uuid4().hex}"
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index f6eb18d11c..6edc7c0744 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -78,12 +78,14 @@ class OpenAIServer:
def __init__(self,
llm: Union[LLM, MultimodalEncoder],
model: str,
+ tool_parser: Optional[str],
server_role: Optional[ServerRole],
metadata_server_cfg: MetadataServerConfig,
disagg_cluster_config: Optional[DisaggClusterConfig] = None,
multimodal_server_config: Optional[MultimodalServerConfig] = None):
self.llm = llm
self.tokenizer = llm.tokenizer
+ self.tool_parser = tool_parser
self.metadata_server = create_metadata_server(metadata_server_cfg)
self.disagg_cluster_config = disagg_cluster_config
self.multimodal_server_config = multimodal_server_config
@@ -532,6 +534,7 @@ class OpenAIServer:
prompt["multi_modal_data"] = mm_data
postproc_args.reasoning_parser = self.llm.args.reasoning_parser
+ postproc_args.tool_parser = self.tool_parser
if conversation and conversation[-1].get(
"content") and conversation[-1].get("role") == get_role():
postproc_args.last_message_content = conversation[-1]["content"]
diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py
index ece007b539..f9d78a5354 100644
--- a/tensorrt_llm/serve/postprocess_handlers.py
+++ b/tensorrt_llm/serve/postprocess_handlers.py
@@ -10,6 +10,7 @@ from ..llmapi.reasoning_parser import (BaseReasoningParser,
ReasoningParserFactory)
from ..llmapi.tokenizer import TransformersTokenizer
# yapf: disable
+from .chat_utils import make_tool_call_id
from .harmony_adapter import (handle_non_streaming_response,
handle_streaming_response)
from .openai_protocol import (ChatCompletionLogProbs,
@@ -23,9 +24,13 @@ from .openai_protocol import (ChatCompletionLogProbs,
CompletionRequest, CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
- CompletionStreamResponse, DeltaMessage,
- FunctionCall, PromptTokensDetails, StreamOptions,
- ToolCall, UsageInfo, to_disaggregated_params)
+ CompletionStreamResponse, DeltaFunctionCall,
+ DeltaMessage, DeltaToolCall, FunctionCall,
+ PromptTokensDetails, StreamOptions, ToolCall,
+ UsageInfo, to_disaggregated_params)
+from .tool_parser.base_tool_parser import BaseToolParser
+from .tool_parser.core_types import ToolCallItem
+from .tool_parser.tool_parser_factory import ToolParserFactory
# yapf: enable
@@ -33,8 +38,8 @@ from .openai_protocol import (ChatCompletionLogProbs,
@dataclass(kw_only=True)
class ChatPostprocArgs(PostprocArgs):
echo: bool = False
- role: str = None
- model: str = None
+ role: str
+ model: str
num_choices: int = 1
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"],
@@ -44,8 +49,11 @@ class ChatPostprocArgs(PostprocArgs):
stream_options: Optional[StreamOptions] = None
last_message_content: Optional[str] = None
reasoning_parser: Optional[str] = None
+ tool_parser: Optional[str] = None
reasoning_parser_dict: dict[int, BaseReasoningParser] = field(
default_factory=dict)
+ tool_parser_dict: dict[int, BaseToolParser] = field(default_factory=dict)
+ has_tool_call: dict[int, bool] = field(default_factory=dict)
@classmethod
def from_request(cls, request: ChatCompletionRequest):
@@ -116,6 +124,31 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
return content, reasoning_content
+def apply_tool_parser(args: ChatPostprocArgs, output_index: int, text: str,
+ streaming: bool) -> Tuple[str, List[ToolCallItem]]:
+ tool_parser = None
+ tools = args.tools
+ if args.tool_parser is not None and tools is not None:
+ if output_index not in args.tool_parser_dict:
+ args.tool_parser_dict[
+ output_index] = ToolParserFactory.create_tool_parser(
+ args.tool_parser)
+ tool_parser = args.tool_parser_dict[output_index]
+
+ if tool_parser is not None and tools is not None:
+ if not streaming:
+ result = tool_parser.detect_and_parse(text, tools)
+ else:
+ result = tool_parser.parse_streaming_increment(text, tools)
+ normal_text, calls = result.normal_text, result.calls
+ if result.calls:
+ args.has_tool_call[output_index] = True
+ else:
+ normal_text, calls = text, []
+
+ return normal_text, calls
+
+
@nvtx_range_debug("chat_stream_post_processor")
def chat_stream_post_processor(rsp: GenerationResultBase,
args: ChatPostprocArgs) -> List[str]:
@@ -176,27 +209,63 @@ def chat_stream_post_processor(rsp: GenerationResultBase,
if args.tool_choice and type(
args.tool_choice) is ChatCompletionNamedToolChoiceParam:
delta_message = DeltaMessage(tool_calls=[
- ToolCall(function=FunctionCall(
- name=args.tool_choice.function.name, arguments=delta_text))
- ])
+ DeltaToolCall(
+ function=DeltaFunctionCall(
+ name=args.tool_choice.function.name,
+ arguments=delta_text),
+ index=i,
+ ),
+ ], )
else:
- delta_message = DeltaMessage(content=delta_text,
- reasoning_content=reasoning_delta_text)
+ delta_text, calls = apply_tool_parser(args, i, delta_text, True)
+ tool_calls = []
+ for call_item in calls:
+ # Tool call ID should be generated only once per tool call
+ if call_item.name:
+ # First chunk: include ID and function name
+ tool_call_id = make_tool_call_id()
+ function_name = call_item.name
+ else:
+ # Subsequent chunks: null ID and name for argument deltas
+ tool_call_id = None
+ function_name = None
+
+ tool_calls.append(
+ DeltaToolCall(
+ id=tool_call_id,
+ index=call_item.tool_index,
+ function=DeltaFunctionCall(
+ name=function_name,
+ arguments=call_item.parameters,
+ ),
+ ))
+ if tool_calls or delta_text or reasoning_delta_text or output.finish_reason:
+ delta_message = DeltaMessage(
+ content=delta_text,
+ reasoning_content=reasoning_delta_text,
+ tool_calls=tool_calls if tool_calls else None)
+ else:
+ continue
choice = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
- finish_reason=None,
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
- None))
+ None),
+ stop_reason=output.stop_reason,
+ )
if args.return_logprobs:
logprobs = output.logprobs_diff
token_ids = output.token_ids_diff
choice.logprobs = create_logprobs(token_ids, args.tokenizer,
logprobs, args.top_logprobs)
if output.finish_reason is not None:
- choice.finish_reason = output.finish_reason
+ if output.finish_reason == "stop" and args.has_tool_call.get(
+ i, False):
+ choice.finish_reason = "tool_calls"
+ else:
+ choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason
finish_reason_sent[i] = True
chunk = ChatCompletionStreamResponse(choices=[choice], model=args.model)
@@ -247,21 +316,34 @@ def chat_response_post_processor(
name=args.tool_choice.function.name, arguments=text))
])
else:
+ if text is None:
+ text = ""
+ text, calls = apply_tool_parser(args, output.index, text, False)
+ tool_calls = [
+ ToolCall(function=FunctionCall(name=call.name or "",
+ arguments=call.parameters))
+ for call in calls
+ ]
message = ChatMessage(role=role,
content=text,
- reasoning_content=reasoning_text)
+ reasoning_content=reasoning_text,
+ tool_calls=tool_calls)
disaggregated_params = to_disaggregated_params(
output.disaggregated_params)
choice = ChatCompletionResponseChoice(
index=output.index,
message=message,
- finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
disaggregated_params=disaggregated_params,
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None),
)
+ if output.finish_reason == "stop" and args.has_tool_call.get(
+ output.index, False):
+ choice.finish_reason = "tool_calls"
+ else:
+ choice.finish_reason = output.finish_reason
if args.return_logprobs:
choice.logprobs = create_logprobs(output.token_ids, args.tokenizer,
diff --git a/tensorrt_llm/serve/tool_parser/__init__.py b/tensorrt_llm/serve/tool_parser/__init__.py
new file mode 100644
index 0000000000..d862620c95
--- /dev/null
+++ b/tensorrt_llm/serve/tool_parser/__init__.py
@@ -0,0 +1,3 @@
+from .tool_parser_factory import ToolParserFactory
+
+__all__ = ["ToolParserFactory"]
diff --git a/tensorrt_llm/serve/tool_parser/base_tool_parser.py b/tensorrt_llm/serve/tool_parser/base_tool_parser.py
new file mode 100644
index 0000000000..9fa87ec0bf
--- /dev/null
+++ b/tensorrt_llm/serve/tool_parser/base_tool_parser.py
@@ -0,0 +1,324 @@
+# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/base_format_detector.py
+import json
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+
+from partial_json_parser.core.exceptions import MalformedJSON
+from partial_json_parser.core.options import Allow
+
+from tensorrt_llm.logger import logger
+
+from ..openai_protocol import ChatCompletionToolsParam as Tool
+from .core_types import StreamingParseResult, ToolCallItem, _GetInfoFunc
+from .utils import find_common_prefix, is_complete_json, partial_json_loads
+
+
+class BaseToolParser(ABC):
+ """Base class providing two sets of interfaces: one-time and streaming incremental."""
+
+ def __init__(self):
+ # Streaming state management
+ # Buffer for accumulating incomplete patterns that arrive across multiple streaming chunks
+ self._buffer = ""
+ # Stores complete tool call info (name and arguments) for each tool being parsed.
+ # Used by serving layer for completion handling when streaming ends.
+ # Format: [{"name": str, "arguments": dict}, ...]
+ self.prev_tool_call_arr: List[Dict] = []
+ # Index of currently streaming tool call. Starts at -1 (no active tool),
+ # increments as each tool completes. Tracks which tool's arguments are streaming.
+ self.current_tool_id: int = -1
+ # Flag for whether current tool's name has been sent to client.
+ # Tool names sent first with empty parameters, then arguments stream incrementally.
+ self.current_tool_name_sent: bool = False
+ # Tracks raw JSON string content streamed to client for each tool's arguments.
+ # Critical for serving layer to calculate remaining content when streaming ends.
+ # Each index corresponds to a tool_id. Example: ['{"location": "San Francisco"', '{"temp": 72']
+ self.streamed_args_for_tool: List[str] = []
+
+ # Token configuration (override in subclasses)
+ self.bot_token = "" # nosec B105
+ self.eot_token = "" # nosec B105
+ self.tool_call_separator = ", "
+
+ def _get_tool_indices(self, tools: List[Tool]) -> Dict[str, int]:
+ """
+ Get a mapping of tool names to their indices in the tools list.
+
+ This utility method creates a dictionary mapping function names to their
+ indices in the tools list, which is commonly needed for tool validation
+ and ToolCallItem creation.
+
+ Args:
+ tools: List of available tools
+
+ Returns:
+ Dictionary mapping tool names to their indices
+ """
+ return {
+ tool.function.name: i
+ for i, tool in enumerate(tools) if tool.function.name
+ }
+
+ def parse_base_json(self, action: Any,
+ tools: List[Tool]) -> List[ToolCallItem]:
+ tool_indices = self._get_tool_indices(tools)
+ if not isinstance(action, list):
+ action = [action]
+
+ results = []
+ for act in action:
+ name = act.get("name")
+ if name and name in tool_indices:
+ results.append(
+ ToolCallItem(
+ tool_index=
+ -1, # Caller should update this based on the actual tools array called
+ name=name,
+ parameters=json.dumps(
+ act.get("parameters") or act.get("arguments", {}),
+ ensure_ascii=False,
+ ),
+ ))
+ else:
+ logger.warning(
+ f"Model attempted to call undefined function: {name}")
+
+ return results
+
+ @abstractmethod
+ def detect_and_parse(self, text: str,
+ tools: List[Tool]) -> StreamingParseResult:
+ """
+ Parses the text in one go. Returns success=True if the format matches, otherwise False.
+ Note that leftover_text here represents "content that this parser will not consume further".
+ """
+ action = json.loads(text)
+ return StreamingParseResult(calls=self.parse_base_json(action, tools))
+
+ def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
+ """
+ Check if buffer ends with a partial bot_token.
+ Return the length of the partial bot_token.
+
+ For some format, the bot_token is not a token in model's vocabulary, such as
+ `[TOOL_CALLS] [` in Mistral.
+ """
+ for i in range(1, min(len(buffer) + 1, len(bot_token))):
+ if bot_token.startswith(buffer[-i:]):
+ return i
+ return 0
+
+ def parse_streaming_increment(self, new_text: str,
+ tools: List[Tool]) -> StreamingParseResult:
+ """
+ Streaming incremental parsing with tool validation.
+
+ This base implementation works best with formats where:
+ 1. bot_token is followed immediately by JSON (e.g., bot_token + JSON_array)
+ 2. JSON can be parsed incrementally using partial_json_loads
+ 3. Multiple tool calls are separated by "; " or ", "
+
+ Examples of incompatible formats (need custom implementation, may reuse some logic from this class):
+ - Each tool call is wrapped in a separate block: See Qwen25Detector
+ - Multiple separate blocks: [TOOL_CALLS] [...] \n [TOOL_CALLS] [...]
+ - Tool call is Pythonic style
+
+ For incompatible formats, detectors should override this method with custom logic.
+ """
+ # Append new text to buffer
+ self._buffer += new_text
+ current_text = self._buffer
+
+ # The current_text has tool_call if it is the start of a new tool call sequence
+ # or it is the start of a new tool call after a tool call separator, when there is a previous tool call
+ if not (self.has_tool_call(current_text) or
+ (self.current_tool_id > 0
+ and current_text.startswith(self.tool_call_separator))):
+ # Only clear buffer if we're sure no tool call is starting
+ if not self._ends_with_partial_token(self._buffer, self.bot_token):
+ normal_text = self._buffer
+ self._buffer = ""
+ if self.eot_token in normal_text:
+ normal_text = normal_text.replace(self.eot_token, "")
+ return StreamingParseResult(normal_text=normal_text)
+ else:
+ # Might be partial bot_token, keep buffering
+ return StreamingParseResult()
+
+ # Build tool indices if not already built
+ if not hasattr(self, "_tool_indices"):
+ self._tool_indices = self._get_tool_indices(tools)
+
+ flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
+
+ try:
+ try:
+ tool_call_pos = current_text.find(self.bot_token)
+ if tool_call_pos != -1:
+ start_idx = tool_call_pos + len(self.bot_token)
+ elif self.current_tool_id > 0 and current_text.startswith(
+ self.tool_call_separator):
+ start_idx = len(self.tool_call_separator)
+ else:
+ start_idx = 0
+
+ if start_idx >= len(current_text):
+ return StreamingParseResult()
+
+ (obj, end_idx) = partial_json_loads(current_text[start_idx:],
+ flags)
+
+ is_current_complete = is_complete_json(
+ current_text[start_idx:start_idx + end_idx])
+
+ # Validate tool name if present
+ if "name" in obj and obj["name"] not in self._tool_indices:
+ # Invalid tool name - reset state
+ self._buffer = ""
+ self.current_tool_id = -1
+ self.current_tool_name_sent = False
+ if self.streamed_args_for_tool:
+ self.streamed_args_for_tool.pop()
+ return StreamingParseResult()
+
+ # Handle parameters/arguments consistency
+ # NOTE: we assume here that the obj is always partial of a single tool call
+ if "parameters" in obj:
+ assert ("arguments" not in obj
+ ), "model generated both parameters and arguments"
+ obj["arguments"] = obj["parameters"]
+
+ current_tool_call = obj
+
+ except MalformedJSON:
+ return StreamingParseResult()
+
+ if not current_tool_call:
+ return StreamingParseResult()
+
+ # Case 1: Handle tool name streaming
+ # This happens when we encounter a tool but haven't sent its name yet
+ if not self.current_tool_name_sent:
+ function_name = current_tool_call.get("name")
+
+ if function_name and function_name in self._tool_indices:
+ # If this is a new tool (current_tool_id was -1), initialize it
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+ self.streamed_args_for_tool.append("")
+ # If this is a subsequent tool, ensure streamed_args_for_tool is large enough
+ elif self.current_tool_id >= len(
+ self.streamed_args_for_tool):
+ while len(self.streamed_args_for_tool
+ ) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+
+ # Send the tool name with empty parameters
+ res = StreamingParseResult(calls=[
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=function_name,
+ parameters="",
+ )
+ ], )
+ self.current_tool_name_sent = True
+ else:
+ res = StreamingParseResult()
+
+ # Case 2: Handle streaming arguments
+ # This happens when we've already sent the tool name and now need to stream arguments incrementally
+ else:
+ cur_arguments = current_tool_call.get("arguments")
+ res = StreamingParseResult()
+
+ if cur_arguments:
+ # Calculate how much of the arguments we've already streamed
+ sent = len(
+ self.streamed_args_for_tool[self.current_tool_id])
+ cur_args_json = json.dumps(cur_arguments)
+ prev_arguments = None
+ if self.current_tool_id < len(self.prev_tool_call_arr):
+ prev_arguments = self.prev_tool_call_arr[
+ self.current_tool_id].get("arguments")
+
+ argument_diff = None
+
+ # If the current tool's JSON is complete, send all remaining arguments
+ if is_current_complete:
+ argument_diff = cur_args_json[sent:]
+ completing_tool_id = (
+ self.current_tool_id
+ ) # Save the ID of the tool that's completing
+
+ # Only remove the processed portion, keep unprocessed content
+ self._buffer = current_text[start_idx + end_idx:]
+
+ if self.current_tool_id < len(self.prev_tool_call_arr):
+ self.prev_tool_call_arr[
+ self.current_tool_id].clear()
+ self.current_tool_name_sent = False
+ self.streamed_args_for_tool[self.current_tool_id] = ""
+ self.current_tool_id += 1
+
+ # If the tool is still being parsed, send incremental changes
+ elif prev_arguments:
+ prev_args_json = json.dumps(prev_arguments)
+ if cur_args_json != prev_args_json:
+ prefix = find_common_prefix(prev_args_json,
+ cur_args_json)
+ argument_diff = prefix[sent:]
+
+ # Send the argument diff if there's something new
+ if argument_diff is not None:
+ # Use the correct tool_index: completing_tool_id for completed tools, current_tool_id for ongoing
+ tool_index_to_use = (completing_tool_id
+ if is_current_complete else
+ self.current_tool_id)
+ res = StreamingParseResult(calls=[
+ ToolCallItem(
+ tool_index=tool_index_to_use,
+ parameters=argument_diff,
+ )
+ ], )
+ if not is_current_complete:
+ self.streamed_args_for_tool[
+ self.current_tool_id] += argument_diff
+
+ # Update prev_tool_call_arr with current state
+ if self.current_tool_id >= 0:
+ # Ensure prev_tool_call_arr is large enough
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ self.prev_tool_call_arr[
+ self.current_tool_id] = current_tool_call
+
+ return res
+
+ except Exception as e:
+ logger.error(f"Error in parse_streaming_increment: {e}")
+ return StreamingParseResult()
+
+ @abstractmethod
+ def has_tool_call(self, text: str) -> bool:
+ """
+ Check if the given text contains function call markers specific to this format.
+ """
+ raise NotImplementedError()
+
+ def supports_structural_tag(self) -> bool:
+ """Return True if this detector supports structural tag format."""
+ return True
+
+ @abstractmethod
+ def structure_info(self) -> _GetInfoFunc:
+ """
+ Return a function that creates StructureInfo for constrained generation.
+
+ The returned function takes a tool name and returns a StructureInfo object
+ containing the begin/end patterns and trigger tokens needed for constrained
+ generation of function calls in this format.
+
+ Returns:
+ A function that takes a tool name (str) and returns StructureInfo
+ """
+ raise NotImplementedError()
diff --git a/tensorrt_llm/serve/tool_parser/core_types.py b/tensorrt_llm/serve/tool_parser/core_types.py
new file mode 100644
index 0000000000..88d14696a6
--- /dev/null
+++ b/tensorrt_llm/serve/tool_parser/core_types.py
@@ -0,0 +1,35 @@
+# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/qwen25_detector.py
+from dataclasses import dataclass
+from typing import Callable, List, Optional
+
+from pydantic import BaseModel
+
+
+class ToolCallItem(BaseModel):
+ """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
+
+ tool_index: int
+ name: Optional[str] = None
+ parameters: str # JSON string
+
+
+class StreamingParseResult(BaseModel):
+ """Result of streaming incremental parsing."""
+
+ normal_text: str = ""
+ calls: List[ToolCallItem] = []
+
+
+@dataclass
+class StructureInfo:
+ begin: str
+ end: str
+ trigger: str
+
+
+"""
+Helper alias of function
+Usually it is a function that takes a name string and returns a StructureInfo object,
+which can be used to construct a structural_tag object
+"""
+_GetInfoFunc = Callable[[str], StructureInfo]
diff --git a/tensorrt_llm/serve/tool_parser/qwen3_tool_parser.py b/tensorrt_llm/serve/tool_parser/qwen3_tool_parser.py
new file mode 100644
index 0000000000..298389f47e
--- /dev/null
+++ b/tensorrt_llm/serve/tool_parser/qwen3_tool_parser.py
@@ -0,0 +1,114 @@
+# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/qwen25_detector.py
+import json
+import re
+from typing import List
+
+from tensorrt_llm.logger import logger
+
+from ..openai_protocol import ChatCompletionToolsParam as Tool
+from .base_tool_parser import BaseToolParser
+from .core_types import StreamingParseResult, StructureInfo, _GetInfoFunc
+
+
+class Qwen3ToolParser(BaseToolParser):
+ """
+ Detector for Qwen 2.5 and Qwen 3 model function call format.
+
+ Format Structure:
+ ```
+ \n{"name":"func1", "arguments":{...}}\n\n\n{"name":"func2", "arguments":{...}}\n
+ ```
+
+ Key Components:
+ - Tool Call Tags: `` and `` wrap each individual call
+ - Function Call Object: JSON object with "name" and "arguments" fields
+
+ Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default
+ """
+
+ def __init__(self):
+ """
+ Initializes the detector with necessary state variables.
+ """
+ super().__init__()
+ self.bot_token = "\n" # nosec B105
+ self.eot_token = "\n" # nosec B105
+ self.tool_call_separator = "\n"
+ self._normal_text_buffer = "" # Buffer for handling partial end tokens
+
+ def has_tool_call(self, text: str) -> bool:
+ """Check if the text contains a Qwen 3 format tool call."""
+ return self.bot_token in text
+
+ def detect_and_parse(self, text: str,
+ tools: List[Tool]) -> StreamingParseResult:
+ """
+ One-time parsing: Detects and parses tool calls in the provided text.
+
+ :param text: The complete text to parse.
+ :param tools: List of available tools.
+ :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+ """
+ idx = text.find(self.bot_token)
+ normal_text = text[:idx].strip() if idx != -1 else text
+ if self.bot_token not in text:
+ return StreamingParseResult(normal_text=normal_text, calls=[])
+
+ # Find all \n...\n blocks
+ pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
+ match_result_list = re.findall(pattern, text, re.DOTALL)
+ calls = []
+ for match_result in match_result_list:
+ try:
+ parsed_call = json.loads(match_result.strip())
+ calls.extend(self.parse_base_json(parsed_call, tools))
+ except json.JSONDecodeError as e:
+ logger.warning(
+ f"Failed to parse JSON part: {match_result}, JSON parse error: {str(e)}"
+ )
+ continue
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+ def parse_streaming_increment(self, new_text: str,
+ tools: List[Tool]) -> StreamingParseResult:
+ """
+ Streaming incremental parsing for Qwen 3 tool calls.
+ Uses base class implementation with buffering to handle partial end tokens.
+ """
+ result = super().parse_streaming_increment(new_text, tools)
+
+ # Handle partial end tokens that are streamed character by character
+ if result.normal_text:
+ self._normal_text_buffer += result.normal_text
+
+ # Check if buffer contains complete end token (without leading newline)
+ end_token_without_newline = self.eot_token[1:] # ""
+ if end_token_without_newline in self._normal_text_buffer:
+ cleaned_text = self._normal_text_buffer.replace(
+ end_token_without_newline, "")
+ self._normal_text_buffer = ""
+ result.normal_text = cleaned_text
+ else:
+ # Check if buffer might contain partial end token at the end
+ partial_match_len = self._ends_with_partial_token(
+ self._normal_text_buffer, end_token_without_newline)
+
+ if partial_match_len:
+ # Keep potential partial match in buffer, return the rest
+ result.normal_text = self._normal_text_buffer[:
+ -partial_match_len]
+ self._normal_text_buffer = self._normal_text_buffer[
+ -partial_match_len:]
+ else:
+ # No partial match, return all buffered text
+ result.normal_text = self._normal_text_buffer
+ self._normal_text_buffer = ""
+
+ return result
+
+ def structure_info(self) -> _GetInfoFunc:
+ return lambda name: StructureInfo(
+ begin='\n{"name":"' + name + '", "arguments":',
+ end="}\n",
+ trigger="",
+ )
diff --git a/tensorrt_llm/serve/tool_parser/tool_parser_factory.py b/tensorrt_llm/serve/tool_parser/tool_parser_factory.py
new file mode 100644
index 0000000000..73b02510a6
--- /dev/null
+++ b/tensorrt_llm/serve/tool_parser/tool_parser_factory.py
@@ -0,0 +1,21 @@
+from typing import Type
+
+from .base_tool_parser import BaseToolParser
+from .qwen3_tool_parser import Qwen3ToolParser
+
+
+class ToolParserFactory:
+ parsers: dict[str, Type[BaseToolParser]] = {
+ "qwen3": Qwen3ToolParser,
+ }
+
+ @staticmethod
+ def create_tool_parser(tool_parser: str) -> BaseToolParser:
+ try:
+ tool_parser_class = ToolParserFactory.parsers[tool_parser.lower()]
+ return tool_parser_class()
+ except KeyError as e:
+ raise ValueError(
+ f"Invalid tool_parser: {tool_parser}\n"
+ f"Supported parsers: {list(ToolParserFactory.parsers.keys())}"
+ ) from e
diff --git a/tensorrt_llm/serve/tool_parser/utils.py b/tensorrt_llm/serve/tool_parser/utils.py
new file mode 100644
index 0000000000..7666036be5
--- /dev/null
+++ b/tensorrt_llm/serve/tool_parser/utils.py
@@ -0,0 +1,56 @@
+# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/qwen25_detector.py
+import json
+from json import JSONDecodeError, JSONDecoder
+from json.decoder import WHITESPACE
+from typing import Any
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+
+def find_common_prefix(s1: str, s2: str) -> str:
+ prefix = ""
+ min_length = min(len(s1), len(s2))
+ for i in range(0, min_length):
+ if s1[i] == s2[i]:
+ prefix += s1[i]
+ else:
+ break
+ return prefix
+
+
+def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
+ """
+ Parse incomplete or partial JSON strings commonly encountered during streaming.
+
+ Args:
+ input_str (str): The potentially incomplete JSON string to parse.
+ flags (Allow): Bitwise flags controlling what types of partial data are allowed.
+ Common flags include:
+ - Allow.STR: Allow partial strings (e.g., '"hello wo' -> 'hello wo')
+ - Allow.OBJ: Allow partial objects (e.g., '{"key":' -> {'key': None})
+ - Allow.ARR: Allow partial arrays (e.g., '[1, 2,' -> [1, 2])
+ - Allow.ALL: Allow all types of partial data
+
+ Returns:
+ Tuple[Any, int]: A tuple containing:
+ - parsed_object: The Python object parsed from the JSON
+ - consumed_length: Number of characters consumed from input_str
+ """
+ try:
+ return (partial_json_parser.loads(input_str, flags), len(input_str))
+ except (JSONDecodeError, IndexError) as e:
+ msg = getattr(e, "msg", str(e))
+ if "Extra data" in msg or "pop from empty list" in msg:
+ start = WHITESPACE.match(input_str, 0).end()
+ obj, end = JSONDecoder().raw_decode(input_str, start)
+ return obj, end
+ raise
+
+
+def is_complete_json(input_str: str) -> bool:
+ try:
+ json.loads(input_str)
+ return True
+ except JSONDecodeError:
+ return False
diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py
index 3f1bab6c94..51b0b66c36 100644
--- a/tests/integration/defs/test_e2e.py
+++ b/tests/integration/defs/test_e2e.py
@@ -1613,6 +1613,13 @@ def test_openai_reasoning(llm_root, llm_venv, backend: str):
])
+def test_openai_tool_call(llm_root, llm_venv):
+ test_root = unittest_path() / "llmapi" / "apps"
+ llm_venv.run_cmd(
+ ["-m", "pytest",
+ str(test_root / "_test_openai_tool_call.py")])
+
+
@pytest.mark.parametrize("sampler", ["torch_sampler", "trtllm_sampler"])
def test_openai_completions_with_logit_bias(llm_root, llm_venv, sampler: str):
test_root = unittest_path() / "llmapi" / "apps"
diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml
index 9aef279b8d..5fc56bd938 100644
--- a/tests/integration/test_lists/test-db/l0_a10.yml
+++ b/tests/integration/test_lists/test-db/l0_a10.yml
@@ -57,6 +57,7 @@ l0_a10:
- test_e2e.py::test_trtllm_serve_top_logprobs[pytorch]
- test_e2e.py::test_openai_misc_example[pytorch]
- test_e2e.py::test_openai_reasoning[pytorch]
+ - test_e2e.py::test_openai_tool_call
- test_e2e.py::test_openai_completions_example[pytorch]
- test_e2e.py::test_openai_chat_example[pytorch] TIMEOUT (90)
- test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
@@ -71,6 +72,8 @@ l0_a10:
- unittest/llmapi/test_additional_model_outputs.py
# executor
- unittest/executor/test_rpc.py
+ # trtllm-serve CPU-only
+ - unittest/llmapi/apps/test_tool_parsers.py
- condition:
ranges:
system_gpu_count:
diff --git a/tests/unittest/llmapi/apps/README.md b/tests/unittest/llmapi/apps/README.md
index ff316cfa1e..4d066a807a 100644
--- a/tests/unittest/llmapi/apps/README.md
+++ b/tests/unittest/llmapi/apps/README.md
@@ -1,3 +1,3 @@
-This directory contains the end-to-end tests for the LLM API applications in `examples/apps`.
+This directory contains the end-to-end tests for `trtllm-serve`.
These tests are triggered in the `test_e2e.py`.
diff --git a/tests/unittest/llmapi/apps/_test_openai_tool_call.py b/tests/unittest/llmapi/apps/_test_openai_tool_call.py
new file mode 100644
index 0000000000..18dc644997
--- /dev/null
+++ b/tests/unittest/llmapi/apps/_test_openai_tool_call.py
@@ -0,0 +1,121 @@
+import json
+import os
+import sys
+
+import openai
+import pytest
+
+from .openai_server import RemoteOpenAIServer
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+from test_llm import get_model_path
+
+TOOLS = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_temperature",
+ "description": "Get current temperature at a location.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type":
+ "string",
+ "description":
+ 'The location to get the temperature for, in the format "City, State, Country".',
+ },
+ "unit": {
+ "type":
+ "string",
+ "enum": ["celsius", "fahrenheit"],
+ "description":
+ 'The unit to return the temperature in. Defaults to "celsius".',
+ },
+ },
+ "required": ["location"],
+ },
+ },
+ },
+]
+
+
+def get_current_temperature(location: str, unit: str = "celsius") -> dict:
+ return {"temperature": 20 if unit == "celsius" else 68}
+
+
+@pytest.fixture(scope="module", ids=["Qwen3-0.6B"])
+def model_name() -> str:
+ return "Qwen3/Qwen3-0.6B"
+
+
+@pytest.fixture(scope="module")
+def server(model_name: str):
+ model_path = get_model_path(model_name)
+ args = ["--tool_parser", "qwen3"]
+ with RemoteOpenAIServer(model_path, cli_args=args) as remote_server:
+ yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server: RemoteOpenAIServer):
+ return server.get_async_client()
+
+
+@pytest.mark.asyncio(loop_scope="module")
+async def test_tool_parser(client: openai.AsyncOpenAI, model_name: str):
+ response = await client.chat.completions.create(
+ model=model_name,
+ messages=[{
+ "role": "user",
+ "content": "What's the temperature in San Francisco now?"
+ }],
+ tools=TOOLS)
+ assert response.choices[0].finish_reason == "tool_calls"
+ message = response.choices[0].message
+ assert message.content is not None
+ assert message.tool_calls is not None
+ assert len(message.tool_calls) == 1
+ tool_call = message.tool_calls[0]
+ assert tool_call.function.name == "get_current_temperature"
+ args = json.loads(tool_call.function.arguments)
+ get_current_temperature(**args)
+
+
+@pytest.mark.asyncio(loop_scope="module")
+async def test_tool_parser_streaming(client: openai.AsyncOpenAI,
+ model_name: str):
+ response = await client.chat.completions.create(
+ model=model_name,
+ messages=[{
+ "role": "user",
+ "content": "What's the temperature in San Francisco now?"
+ }],
+ tools=TOOLS,
+ stream=True)
+ tool_id = None
+ tool_name = None
+ parameters = ""
+ finish_reason = None
+
+ async for chunk in response:
+ if chunk.choices[0].delta.tool_calls:
+ tool_call = chunk.choices[0].delta.tool_calls[0]
+ if tool_call.id:
+ if tool_id is not None:
+ raise RuntimeError("tool_id already exists")
+ tool_id = tool_call.id
+ if tool_call.function.name:
+ if tool_name is not None:
+ raise RuntimeError("tool_name already exists")
+ tool_name = tool_call.function.name
+ if tool_call.function.arguments:
+ parameters += tool_call.function.arguments
+ if chunk.choices[0].finish_reason:
+ finish_reason = chunk.choices[0].finish_reason
+ assert tool_id is not None
+ assert tool_name == "get_current_temperature"
+ assert finish_reason == "tool_calls"
+ assert parameters
+ args = json.loads(parameters)
+ get_current_temperature(**args)
diff --git a/tests/unittest/llmapi/apps/test_tool_parsers.py b/tests/unittest/llmapi/apps/test_tool_parsers.py
new file mode 100644
index 0000000000..511f6a47fb
--- /dev/null
+++ b/tests/unittest/llmapi/apps/test_tool_parsers.py
@@ -0,0 +1,597 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import pytest
+
+from tensorrt_llm.serve.openai_protocol import (ChatCompletionToolsParam,
+ FunctionDefinition)
+from tensorrt_llm.serve.tool_parser.base_tool_parser import BaseToolParser
+from tensorrt_llm.serve.tool_parser.core_types import StructureInfo
+from tensorrt_llm.serve.tool_parser.qwen3_tool_parser import Qwen3ToolParser
+
+
+# Test fixtures for common tools
+@pytest.fixture
+def sample_tools():
+ """Sample tools for testing."""
+ return [
+ ChatCompletionToolsParam(
+ type="function",
+ function=FunctionDefinition(name="get_weather",
+ description="Get the current weather",
+ parameters={
+ "type": "object",
+ "properties": {
+ "location": {
+ "type":
+ "string",
+ "description":
+ "The city and state"
+ },
+ "unit": {
+ "type":
+ "string",
+ "enum":
+ ["celsius", "fahrenheit"]
+ }
+ },
+ "required": ["location"]
+ })),
+ ChatCompletionToolsParam(
+ type="function",
+ function=FunctionDefinition(name="search_web",
+ description="Search the web",
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type":
+ "string",
+ "description":
+ "The search query"
+ }
+ },
+ "required": ["query"]
+ }))
+ ]
+
+
+# Concrete implementation of BaseToolParser for testing
+class ConcreteToolParser(BaseToolParser):
+ """Concrete implementation of BaseToolParser for testing abstract methods."""
+
+ def __init__(self):
+ super().__init__()
+ self.bot_token = "[TOOL_CALLS] "
+ self.eot_token = "[/TOOL_CALLS]"
+
+ def has_tool_call(self, text: str) -> bool:
+ return self.bot_token in text
+
+ def detect_and_parse(self, text: str, tools):
+ # Placeholder to avoid NotImplementedError
+ pass
+
+ def structure_info(self):
+ return lambda name: StructureInfo(
+ begin=f'[TOOL_CALLS] {{"name":"{name}", "arguments":',
+ end="}[/TOOL_CALLS]",
+ trigger="[TOOL_CALLS]")
+
+
+# ============================================================================
+# BaseToolParser Tests
+# ============================================================================
+
+
+class TestBaseToolParser:
+ """Test suite for BaseToolParser class."""
+
+ def test_initialization(self):
+ """Test that BaseToolParser initializes correctly."""
+ parser = ConcreteToolParser()
+ assert parser._buffer == ""
+ assert parser.prev_tool_call_arr == []
+ assert parser.current_tool_id == -1
+ assert parser.current_tool_name_sent is False
+ assert parser.streamed_args_for_tool == []
+
+ def test_get_tool_indices(self, sample_tools):
+ """Test _get_tool_indices correctly maps tool names to indices."""
+ parser = ConcreteToolParser()
+ indices = parser._get_tool_indices(sample_tools)
+
+ assert len(indices) == 2
+ assert indices["get_weather"] == 0
+ assert indices["search_web"] == 1
+
+ def test_get_tool_indices_empty(self):
+ """Test _get_tool_indices with empty tools list."""
+ parser = ConcreteToolParser()
+ indices = parser._get_tool_indices([])
+ assert indices == {}
+
+ def test_parse_base_json_single_tool(self, sample_tools):
+ """Test parse_base_json with a single tool call."""
+ parser = ConcreteToolParser()
+ action = {
+ "name": "get_weather",
+ "parameters": {
+ "location": "San Francisco"
+ }
+ }
+
+ results = parser.parse_base_json(action, sample_tools)
+
+ assert len(results) == 1
+ assert results[0].name == "get_weather"
+ assert json.loads(results[0].parameters) == {
+ "location": "San Francisco"
+ }
+
+ def test_parse_base_json_with_arguments_key(self, sample_tools):
+ """Test parse_base_json handles 'arguments' key instead of 'parameters'."""
+ parser = ConcreteToolParser()
+ action = {"name": "search_web", "arguments": {"query": "TensorRT"}}
+
+ results = parser.parse_base_json(action, sample_tools)
+
+ assert len(results) == 1
+ assert results[0].name == "search_web"
+ assert json.loads(results[0].parameters) == {"query": "TensorRT"}
+
+ def test_parse_base_json_multiple_tools(self, sample_tools):
+ """Test parse_base_json with multiple tool calls."""
+ parser = ConcreteToolParser()
+ actions = [{
+ "name": "get_weather",
+ "parameters": {
+ "location": "Boston"
+ }
+ }, {
+ "name": "search_web",
+ "arguments": {
+ "query": "Python"
+ }
+ }]
+
+ results = parser.parse_base_json(actions, sample_tools)
+
+ assert len(results) == 2
+ assert results[0].name == "get_weather"
+ assert results[1].name == "search_web"
+
+ def test_parse_base_json_undefined_function(self, sample_tools):
+ """Test parse_base_json handles undefined function names gracefully."""
+ parser = ConcreteToolParser()
+ action = {"name": "undefined_function", "parameters": {}}
+
+ results = parser.parse_base_json(action, sample_tools)
+
+ # Should return empty list and log warning
+ assert len(results) == 0
+
+ def test_parse_base_json_missing_parameters(self, sample_tools):
+ """Test parse_base_json handles missing parameters."""
+ parser = ConcreteToolParser()
+ action = {"name": "get_weather"}
+
+ results = parser.parse_base_json(action, sample_tools)
+
+ assert len(results) == 1
+ assert json.loads(results[0].parameters) == {}
+
+ def test_ends_with_partial_token(self):
+ """Test _ends_with_partial_token detection."""
+ parser = ConcreteToolParser()
+
+ # Partial token at end (bot_token starts with the suffix)
+ assert parser._ends_with_partial_token("Some text [TOOL",
+ "[TOOL_CALLS] ") == 5
+ assert parser._ends_with_partial_token("Some text [",
+ "[TOOL_CALLS] ") == 1
+ assert parser._ends_with_partial_token("Some text [TOOL_CALLS",
+ "[TOOL_CALLS] ") == 11
+
+ # No partial token
+ assert parser._ends_with_partial_token("Some text",
+ "[TOOL_CALLS] ") == 0
+ assert parser._ends_with_partial_token("Some text [XYZ",
+ "[TOOL_CALLS] ") == 0
+
+ # Complete token at end (entire buffer is bot_token prefix but not complete match)
+ # When buffer equals bot_token, it returns 0 because it's not a partial anymore
+ assert parser._ends_with_partial_token("text [TOOL_CALLS] ",
+ "[TOOL_CALLS] ") == 0
+
+ def test_parse_streaming_increment_no_tool_call(self, sample_tools):
+ """Test streaming parser returns normal text when no tool call present."""
+ parser = ConcreteToolParser()
+
+ result = parser.parse_streaming_increment("Hello, world!", sample_tools)
+
+ assert result.normal_text == "Hello, world!"
+ assert len(result.calls) == 0
+
+ def test_parse_streaming_increment_partial_bot_token(self, sample_tools):
+ """Test streaming parser buffers partial bot token."""
+ parser = ConcreteToolParser()
+
+ # Send partial bot token
+ result = parser.parse_streaming_increment("[TOOL", sample_tools)
+
+ # Should buffer and return nothing
+ assert result.normal_text == ""
+ assert len(result.calls) == 0
+ assert parser._buffer == "[TOOL"
+
+ def test_parse_streaming_increment_tool_name(self, sample_tools):
+ """Test streaming parser handles tool name streaming."""
+ parser = ConcreteToolParser()
+
+ # Send bot token with partial JSON containing name
+ result = parser.parse_streaming_increment(
+ '[TOOL_CALLS] {"name":"get_weather"', sample_tools)
+
+ # Should send tool name with empty parameters
+ assert len(result.calls) == 1
+ assert result.calls[0].name == "get_weather"
+ assert result.calls[0].parameters == ""
+ assert result.calls[0].tool_index == 0
+ assert parser.current_tool_name_sent is True
+
+ def test_parse_streaming_increment_tool_arguments(self, sample_tools):
+ """Test streaming parser handles incremental argument streaming."""
+ parser = ConcreteToolParser()
+
+ # First send tool name
+ result1 = parser.parse_streaming_increment(
+ '[TOOL_CALLS] {"name":"get_weather"', sample_tools)
+ # Should send tool name
+ assert len(result1.calls) == 1
+ assert result1.calls[0].name == "get_weather"
+
+ # Then send complete arguments (parser needs complete JSON to parse incrementally)
+ result2 = parser.parse_streaming_increment(
+ ',"arguments":{"location":"San Francisco"}}', sample_tools)
+
+ # Should stream arguments or complete the tool call
+ # The base implementation uses partial JSON parsing, so it may return results
+ assert result2 is not None # Just verify it doesn't crash
+
+ def test_parse_streaming_increment_complete_tool(self, sample_tools):
+ """Test streaming parser handles complete tool call."""
+ parser = ConcreteToolParser()
+
+ # Send complete tool call in one chunk
+ result = parser.parse_streaming_increment(
+ '[TOOL_CALLS] {"name":"get_weather","arguments":{"location":"Boston"}}',
+ sample_tools)
+
+ # Should have sent tool name (first call)
+ assert len(result.calls) == 1
+ assert result.calls[0].name == "get_weather"
+
+ def test_parse_streaming_increment_invalid_tool_name(self, sample_tools):
+ """Test streaming parser handles invalid tool name."""
+ parser = ConcreteToolParser()
+
+ # Send invalid tool name
+ result = parser.parse_streaming_increment(
+ '[TOOL_CALLS] {"name":"invalid_tool"', sample_tools)
+
+ # Should reset state
+ assert len(result.calls) == 0
+ assert parser._buffer == ""
+ assert parser.current_tool_id == -1
+
+ def test_supports_structural_tag(self):
+ """Test supports_structural_tag returns True."""
+ parser = ConcreteToolParser()
+ assert parser.supports_structural_tag() is True
+
+ def test_structure_info(self):
+ """Test structure_info returns proper function."""
+ parser = ConcreteToolParser()
+ func = parser.structure_info()
+
+ info = func("test_function")
+ assert isinstance(info, StructureInfo)
+ assert "test_function" in info.begin
+ assert info.trigger == "[TOOL_CALLS]"
+
+
+# ============================================================================
+# Qwen3ToolParser Tests
+# ============================================================================
+
+
+class TestQwen3ToolParser:
+ """Test suite for Qwen3ToolParser class."""
+
+ def test_initialization(self):
+ """Test that Qwen3ToolParser initializes correctly."""
+ parser = Qwen3ToolParser()
+
+ assert parser.bot_token == "\n"
+ assert parser.eot_token == "\n"
+ assert parser.tool_call_separator == "\n"
+ assert parser._normal_text_buffer == ""
+
+ def test_has_tool_call_true(self):
+ """Test has_tool_call returns True when tool call is present."""
+ parser = Qwen3ToolParser()
+ text = 'Some text \n{"name":"get_weather"}\n'
+
+ assert parser.has_tool_call(text) is True
+
+ def test_has_tool_call_false(self):
+ """Test has_tool_call returns False when no tool call present."""
+ parser = Qwen3ToolParser()
+ text = "Just some regular text without tool calls"
+
+ assert parser.has_tool_call(text) is False
+
+ def test_detect_and_parse_no_tool_call(self, sample_tools):
+ """Test detect_and_parse with text containing no tool calls."""
+ parser = Qwen3ToolParser()
+ text = "This is just a regular response."
+
+ result = parser.detect_and_parse(text, sample_tools)
+
+ assert result.normal_text == "This is just a regular response."
+ assert len(result.calls) == 0
+
+ def test_detect_and_parse_single_tool(self, sample_tools):
+ """Test detect_and_parse with a single tool call."""
+ parser = Qwen3ToolParser()
+ text = 'Normal text\n\n{"name":"get_weather","arguments":{"location":"NYC"}}\n'
+
+ result = parser.detect_and_parse(text, sample_tools)
+
+ assert result.normal_text == "Normal text"
+ assert len(result.calls) == 1
+ assert result.calls[0].name == "get_weather"
+ assert json.loads(result.calls[0].parameters) == {"location": "NYC"}
+
+ def test_detect_and_parse_multiple_tools(self, sample_tools):
+ """Test detect_and_parse with multiple tool calls."""
+ parser = Qwen3ToolParser()
+ text = (
+ '\n{"name":"get_weather","arguments":{"location":"LA"}}\n\n'
+ '\n{"name":"search_web","arguments":{"query":"AI"}}\n'
+ )
+
+ result = parser.detect_and_parse(text, sample_tools)
+
+ assert len(result.calls) == 2
+ assert result.calls[0].name == "get_weather"
+ assert result.calls[1].name == "search_web"
+
+ def test_detect_and_parse_malformed_json(self, sample_tools):
+ """Test detect_and_parse handles malformed JSON gracefully."""
+ parser = Qwen3ToolParser()
+ text = '\n{"name":"get_weather","arguments":MALFORMED}\n'
+
+ result = parser.detect_and_parse(text, sample_tools)
+
+ # Should return empty calls due to JSON parsing error
+ assert len(result.calls) == 0
+
+ def test_detect_and_parse_with_parameters_key(self, sample_tools):
+ """Test detect_and_parse handles 'parameters' key."""
+ parser = Qwen3ToolParser()
+ text = '\n{"name":"search_web","parameters":{"query":"test"}}\n'
+
+ result = parser.detect_and_parse(text, sample_tools)
+
+ assert len(result.calls) == 1
+ assert result.calls[0].name == "search_web"
+ assert json.loads(result.calls[0].parameters) == {"query": "test"}
+
+ def test_parse_streaming_increment_normal_text(self, sample_tools):
+ """Test streaming parser handles normal text without tool calls."""
+ parser = Qwen3ToolParser()
+
+ result = parser.parse_streaming_increment("Hello, how can I help?",
+ sample_tools)
+
+ assert result.normal_text == "Hello, how can I help?"
+ assert len(result.calls) == 0
+
+ def test_parse_streaming_increment_partial_bot_token(self, sample_tools):
+ """Test streaming parser buffers partial bot token."""
+ parser = Qwen3ToolParser()
+
+ # Send partial bot token
+ result = parser.parse_streaming_increment("\n", sample_tools)
+
+ # Send partial JSON with name
+ result = parser.parse_streaming_increment('{"name":"get_weather"',
+ sample_tools)
+
+ # Should send tool name
+ assert len(result.calls) == 1
+ assert result.calls[0].name == "get_weather"
+ assert result.calls[0].parameters == ""
+
+ # Send arguments
+ result = parser.parse_streaming_increment(
+ ',"arguments":{"location":"SF"}}\n', sample_tools)
+
+ # Should stream arguments
+ assert len(result.calls) == 1
+ assert json.loads(result.calls[0].parameters) == {"location": "SF"}
+
+ def test_parse_streaming_increment_end_token_handling(self, sample_tools):
+ """Test streaming parser handles end token correctly."""
+ parser = Qwen3ToolParser()
+
+ # Send complete tool call
+ parser.parse_streaming_increment(
+ '\n{"name":"get_weather","arguments":{"location":"NYC"}}\n',
+ sample_tools)
+
+ # The end token should be removed from normal text
+ # Check buffer state
+ assert parser._normal_text_buffer == ""
+
+ def test_parse_streaming_increment_multiple_tools_streaming(
+ self, sample_tools):
+ """Test streaming parser handles multiple tool calls."""
+ parser = Qwen3ToolParser()
+
+ # First tool
+ parser.parse_streaming_increment('\n', sample_tools)
+ parser.parse_streaming_increment(
+ '{"name":"get_weather","arguments":{"location":"NYC"}}\n\n',
+ sample_tools)
+
+ # Second tool
+ parser.parse_streaming_increment('\n', sample_tools)
+ result = parser.parse_streaming_increment('{"name":"search_web"',
+ sample_tools)
+
+ # Should have started second tool
+ assert result.calls[0].name == "search_web"
+ assert result.calls[0].parameters == ""
+ assert result.calls[0].tool_index == 1
+
+ def test_structure_info_function(self):
+ """Test structure_info returns correct lambda function."""
+ parser = Qwen3ToolParser()
+ func = parser.structure_info()
+
+ info = func("test_function")
+
+ assert isinstance(info, StructureInfo)
+ assert info.begin == '\n{"name":"test_function", "arguments":'
+ assert info.end == "}\n"
+ assert info.trigger == ""
+
+ def test_structure_info_different_names(self):
+ """Test structure_info works with different function names."""
+ parser = Qwen3ToolParser()
+ func = parser.structure_info()
+
+ info1 = func("get_weather")
+ info2 = func("search_web")
+
+ assert "get_weather" in info1.begin
+ assert "search_web" in info2.begin
+ assert info1.end == info2.end == "}\n"
+
+ def test_qwen3_format_compliance(self, sample_tools):
+ """Test that Qwen3ToolParser follows the documented format structure."""
+ parser = Qwen3ToolParser()
+
+ # Test the exact format from the docstring
+ text = '\n{"name":"get_weather", "arguments":{"location":"Tokyo"}}\n'
+
+ result = parser.detect_and_parse(text, sample_tools)
+
+ assert len(result.calls) == 1
+ assert result.calls[0].name == "get_weather"
+ assert json.loads(result.calls[0].parameters) == {"location": "Tokyo"}
+
+ def test_undefined_tool_in_qwen3_format(self, sample_tools):
+ """Test Qwen3ToolParser handles undefined tool gracefully."""
+ parser = Qwen3ToolParser()
+ text = '\n{"name":"undefined_func","arguments":{}}\n'
+
+ result = parser.detect_and_parse(text, sample_tools)
+
+ # Should not return any calls for undefined function
+ assert len(result.calls) == 0
+
+
+# ============================================================================
+# Integration Tests
+# ============================================================================
+
+
+class TestToolParserIntegration:
+ """Integration tests for tool parsers."""
+
+ def test_end_to_end_single_tool(self, sample_tools):
+ """Test end-to-end parsing of a single tool call."""
+ parser = Qwen3ToolParser()
+
+ # Simulate streaming
+ chunks = [
+ "\n", '{"name":"get', '_weather"', ',"arguments":',
+ '{"location"', ':"Paris"}}\n', ''
+ ]
+
+ results = []
+ for chunk in chunks:
+ result = parser.parse_streaming_increment(chunk, sample_tools)
+ if result.calls or result.normal_text:
+ results.append(result)
+
+ # Should have received tool name and arguments
+ assert any(r.calls for r in results)
+
+ def test_mixed_content_and_tool_calls(self, sample_tools):
+ """Test parsing text that mixes normal content with tool calls."""
+ parser = Qwen3ToolParser()
+
+ text = (
+ 'I will check the weather for you.\n'
+ '\n{"name":"get_weather","arguments":{"location":"London"}}\n\n'
+ 'Let me search that for you.')
+
+ result = parser.detect_and_parse(text, sample_tools)
+
+ assert "I will check the weather for you." in result.normal_text
+ assert len(result.calls) == 1
+ assert result.calls[0].name == "get_weather"
+
+ def test_parser_state_reset(self, sample_tools):
+ """Test that parser state can be used for multiple requests."""
+ parser = Qwen3ToolParser()
+
+ # First request
+ result1 = parser.detect_and_parse(
+ '\n{"name":"get_weather","arguments":{"location":"NYC"}}\n',
+ sample_tools)
+
+ # Reset internal state for new request
+ parser2 = Qwen3ToolParser()
+
+ # Second request
+ result2 = parser2.detect_and_parse(
+ '\n{"name":"search_web","arguments":{"query":"test"}}\n',
+ sample_tools)
+
+ assert result1.calls[0].name == "get_weather"
+ assert result2.calls[0].name == "search_web"
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])