[TRTLLM-8214][feat] Support Qwen3 tool parser (#8216)

Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
Pengyun Lin 2025-10-29 15:48:29 +08:00 committed by GitHub
parent 741183917c
commit 2aade46d18
16 changed files with 1405 additions and 19 deletions

View File

@@ -78,3 +78,4 @@ nvidia-cutlass-dsl==4.2.1; python_version >= "3.10"
numba-cuda>=0.19.0 # WAR for nvbugs/5501820
plotly
numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
partial_json_parser

View File

@@ -33,6 +33,7 @@ from tensorrt_llm.llmapi.mpi_session import find_free_port
from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
from tensorrt_llm.logger import logger, severity_map
from tensorrt_llm.serve import OpenAIDisaggServer, OpenAIServer
from tensorrt_llm.serve.tool_parser import ToolParserFactory
# Global variable to store the Popen object of the child process
_child_p_global: Optional[subprocess.Popen] = None
@@ -150,6 +151,7 @@ def launch_server(
host: str,
port: int,
llm_args: dict,
tool_parser: Optional[str] = None,
metadata_server_cfg: Optional[MetadataServerConfig] = None,
server_role: Optional[ServerRole] = None,
disagg_cluster_config: Optional[DisaggClusterConfig] = None,
@@ -173,6 +175,7 @@
server = OpenAIServer(llm=llm,
model=model,
tool_parser=tool_parser,
server_role=server_role,
metadata_server_cfg=metadata_server_cfg,
disagg_cluster_config=disagg_cluster_config,
@@ -311,6 +314,12 @@ class ChoiceWithAlias(click.Choice):
default=None,
help="[Experimental] Specify the parser for reasoning models.",
)
@click.option(
"--tool_parser",
type=click.Choice(ToolParserFactory.parsers.keys()),
default=None,
help="[Experimental] Specify the parser for tool models.",
)
@click.option("--metadata_server_config_file",
type=str,
default=None,
@@ -352,7 +361,8 @@ def serve(
gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float,
num_postprocess_workers: int, trust_remote_code: bool,
extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
metadata_server_config_file: Optional[str], server_role: Optional[str],
tool_parser: Optional[str], metadata_server_config_file: Optional[str],
server_role: Optional[str],
fail_fast_on_attention_window_too_large: bool,
otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str]):
@@ -423,8 +433,8 @@ def serve(
multimodal_server_config = MultimodalServerConfig(
media_io_kwargs=parsed_media_io_kwargs)
launch_server(host, port, llm_args, metadata_server_cfg, server_role,
disagg_cluster_config, multimodal_server_config)
launch_server(host, port, llm_args, tool_parser, metadata_server_cfg,
server_role, disagg_cluster_config, multimodal_server_config)
@click.command("mm_embedding_serve")

View File

@@ -1,3 +1,4 @@
import uuid
from functools import partial
from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
Optional, Tuple, TypeAlias, TypedDict, Union, cast)
@@ -220,3 +221,11 @@ def check_multiple_response(n: int, backend: Optional[str]):
if n > 1 and backend == "pytorch":
raise ValueError(
"Multiple response is not supported in PyTorch workflow")
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
if id_type == "kimi_k2":
return f"functions.{func_name}:{idx}"
else:
# by default return random
return f"chatcmpl-tool-{uuid.uuid4().hex}"

View File

@@ -78,12 +78,14 @@ class OpenAIServer:
def __init__(self,
llm: Union[LLM, MultimodalEncoder],
model: str,
tool_parser: Optional[str],
server_role: Optional[ServerRole],
metadata_server_cfg: MetadataServerConfig,
disagg_cluster_config: Optional[DisaggClusterConfig] = None,
multimodal_server_config: Optional[MultimodalServerConfig] = None):
self.llm = llm
self.tokenizer = llm.tokenizer
self.tool_parser = tool_parser
self.metadata_server = create_metadata_server(metadata_server_cfg)
self.disagg_cluster_config = disagg_cluster_config
self.multimodal_server_config = multimodal_server_config
@@ -532,6 +534,7 @@ class OpenAIServer:
prompt["multi_modal_data"] = mm_data
postproc_args.reasoning_parser = self.llm.args.reasoning_parser
postproc_args.tool_parser = self.tool_parser
if conversation and conversation[-1].get(
"content") and conversation[-1].get("role") == get_role():
postproc_args.last_message_content = conversation[-1]["content"]

View File

@@ -10,6 +10,7 @@ from ..llmapi.reasoning_parser import (BaseReasoningParser,
ReasoningParserFactory)
from ..llmapi.tokenizer import TransformersTokenizer
# yapf: disable
from .chat_utils import make_tool_call_id
from .harmony_adapter import (handle_non_streaming_response,
handle_streaming_response)
from .openai_protocol import (ChatCompletionLogProbs,
@@ -23,9 +24,13 @@ from .openai_protocol import (ChatCompletionLogProbs,
CompletionRequest, CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse, DeltaMessage,
FunctionCall, PromptTokensDetails, StreamOptions,
ToolCall, UsageInfo, to_disaggregated_params)
CompletionStreamResponse, DeltaFunctionCall,
DeltaMessage, DeltaToolCall, FunctionCall,
PromptTokensDetails, StreamOptions, ToolCall,
UsageInfo, to_disaggregated_params)
from .tool_parser.base_tool_parser import BaseToolParser
from .tool_parser.core_types import ToolCallItem
from .tool_parser.tool_parser_factory import ToolParserFactory
# yapf: enable
@@ -33,8 +38,8 @@ from .openai_protocol import (ChatCompletionLogProbs,
@dataclass(kw_only=True)
class ChatPostprocArgs(PostprocArgs):
echo: bool = False
role: str = None
model: str = None
role: str
model: str
num_choices: int = 1
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"],
@@ -44,8 +49,11 @@ class ChatPostprocArgs(PostprocArgs):
stream_options: Optional[StreamOptions] = None
last_message_content: Optional[str] = None
reasoning_parser: Optional[str] = None
tool_parser: Optional[str] = None
reasoning_parser_dict: dict[int, BaseReasoningParser] = field(
default_factory=dict)
tool_parser_dict: dict[int, BaseToolParser] = field(default_factory=dict)
has_tool_call: dict[int, bool] = field(default_factory=dict)
@classmethod
def from_request(cls, request: ChatCompletionRequest):
@@ -116,6 +124,31 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
return content, reasoning_content
def apply_tool_parser(args: ChatPostprocArgs, output_index: int, text: str,
streaming: bool) -> Tuple[str, List[ToolCallItem]]:
tool_parser = None
tools = args.tools
if args.tool_parser is not None and tools is not None:
if output_index not in args.tool_parser_dict:
args.tool_parser_dict[
output_index] = ToolParserFactory.create_tool_parser(
args.tool_parser)
tool_parser = args.tool_parser_dict[output_index]
if tool_parser is not None and tools is not None:
if not streaming:
result = tool_parser.detect_and_parse(text, tools)
else:
result = tool_parser.parse_streaming_increment(text, tools)
normal_text, calls = result.normal_text, result.calls
if result.calls:
args.has_tool_call[output_index] = True
else:
normal_text, calls = text, []
return normal_text, calls
@nvtx_range_debug("chat_stream_post_processor")
def chat_stream_post_processor(rsp: GenerationResultBase,
args: ChatPostprocArgs) -> List[str]:
@@ -176,27 +209,63 @@ def chat_stream_post_processor(rsp: GenerationResultBase,
if args.tool_choice and type(
args.tool_choice) is ChatCompletionNamedToolChoiceParam:
delta_message = DeltaMessage(tool_calls=[
ToolCall(function=FunctionCall(
name=args.tool_choice.function.name, arguments=delta_text))
])
DeltaToolCall(
function=DeltaFunctionCall(
name=args.tool_choice.function.name,
arguments=delta_text),
index=i,
),
], )
else:
delta_message = DeltaMessage(content=delta_text,
reasoning_content=reasoning_delta_text)
delta_text, calls = apply_tool_parser(args, i, delta_text, True)
tool_calls = []
for call_item in calls:
# Tool call ID should be generated only once per tool call
if call_item.name:
# First chunk: include ID and function name
tool_call_id = make_tool_call_id()
function_name = call_item.name
else:
# Subsequent chunks: null ID and name for argument deltas
tool_call_id = None
function_name = None
tool_calls.append(
DeltaToolCall(
id=tool_call_id,
index=call_item.tool_index,
function=DeltaFunctionCall(
name=function_name,
arguments=call_item.parameters,
),
))
if tool_calls or delta_text or reasoning_delta_text or output.finish_reason:
delta_message = DeltaMessage(
content=delta_text,
reasoning_content=reasoning_delta_text,
tool_calls=tool_calls if tool_calls else None)
else:
continue
choice = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
finish_reason=None,
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None))
None),
stop_reason=output.stop_reason,
)
if args.return_logprobs:
logprobs = output.logprobs_diff
token_ids = output.token_ids_diff
choice.logprobs = create_logprobs(token_ids, args.tokenizer,
logprobs, args.top_logprobs)
if output.finish_reason is not None:
choice.finish_reason = output.finish_reason
if output.finish_reason == "stop" and args.has_tool_call.get(
i, False):
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason
finish_reason_sent[i] = True
chunk = ChatCompletionStreamResponse(choices=[choice], model=args.model)
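For intuition, the tool-call branch above emits OpenAI-style deltas in roughly this sequence (a sketch; all field values are illustrative):

```
# First chunk of a call: ID and function name, with empty arguments
DeltaToolCall(id="chatcmpl-tool-...", index=0,
              function=DeltaFunctionCall(name="get_weather", arguments=""))
# Later chunks: id and name are None, argument JSON arrives as fragments
DeltaToolCall(id=None, index=0,
              function=DeltaFunctionCall(name=None, arguments='{"location": "Paris"}'))
# When generation stops after a tool call, finish_reason is remapped to "tool_calls".
```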
@@ -247,21 +316,34 @@ def chat_response_post_processor(
name=args.tool_choice.function.name, arguments=text))
])
else:
if text is None:
text = ""
text, calls = apply_tool_parser(args, output.index, text, False)
tool_calls = [
ToolCall(function=FunctionCall(name=call.name or "",
arguments=call.parameters))
for call in calls
]
message = ChatMessage(role=role,
content=text,
reasoning_content=reasoning_text)
reasoning_content=reasoning_text,
tool_calls=tool_calls)
disaggregated_params = to_disaggregated_params(
output.disaggregated_params)
choice = ChatCompletionResponseChoice(
index=output.index,
message=message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
disaggregated_params=disaggregated_params,
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None),
)
if output.finish_reason == "stop" and args.has_tool_call.get(
output.index, False):
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = output.finish_reason
if args.return_logprobs:
choice.logprobs = create_logprobs(output.token_ids, args.tokenizer,

View File

@@ -0,0 +1,3 @@
from .tool_parser_factory import ToolParserFactory
__all__ = ["ToolParserFactory"]

View File

@@ -0,0 +1,324 @@
# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/base_format_detector.py
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow
from tensorrt_llm.logger import logger
from ..openai_protocol import ChatCompletionToolsParam as Tool
from .core_types import StreamingParseResult, ToolCallItem, _GetInfoFunc
from .utils import find_common_prefix, is_complete_json, partial_json_loads
class BaseToolParser(ABC):
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
def __init__(self):
# Streaming state management
# Buffer for accumulating incomplete patterns that arrive across multiple streaming chunks
self._buffer = ""
# Stores complete tool call info (name and arguments) for each tool being parsed.
# Used by serving layer for completion handling when streaming ends.
# Format: [{"name": str, "arguments": dict}, ...]
self.prev_tool_call_arr: List[Dict] = []
# Index of currently streaming tool call. Starts at -1 (no active tool),
# increments as each tool completes. Tracks which tool's arguments are streaming.
self.current_tool_id: int = -1
# Flag for whether current tool's name has been sent to client.
# Tool names are sent first with empty parameters; arguments then stream incrementally.
self.current_tool_name_sent: bool = False
# Tracks raw JSON string content streamed to client for each tool's arguments.
# Critical for serving layer to calculate remaining content when streaming ends.
# Each index corresponds to a tool_id. Example: ['{"location": "San Francisco"', '{"temp": 72']
self.streamed_args_for_tool: List[str] = []
# Token configuration (override in subclasses)
self.bot_token = "" # nosec B105
self.eot_token = "" # nosec B105
self.tool_call_separator = ", "
def _get_tool_indices(self, tools: List[Tool]) -> Dict[str, int]:
"""
Get a mapping of tool names to their indices in the tools list.
This utility method creates a dictionary mapping function names to their
indices in the tools list, which is commonly needed for tool validation
and ToolCallItem creation.
Args:
tools: List of available tools
Returns:
Dictionary mapping tool names to their indices
"""
return {
tool.function.name: i
for i, tool in enumerate(tools) if tool.function.name
}
def parse_base_json(self, action: Any,
tools: List[Tool]) -> List[ToolCallItem]:
tool_indices = self._get_tool_indices(tools)
if not isinstance(action, list):
action = [action]
results = []
for act in action:
name = act.get("name")
if name and name in tool_indices:
results.append(
ToolCallItem(
tool_index=
-1, # Caller should update this based on the actual tools array called
name=name,
parameters=json.dumps(
act.get("parameters") or act.get("arguments", {}),
ensure_ascii=False,
),
))
else:
logger.warning(
f"Model attempted to call undefined function: {name}")
return results
@abstractmethod
def detect_and_parse(self, text: str,
tools: List[Tool]) -> StreamingParseResult:
"""
Parses the complete text in one pass and returns a StreamingParseResult.
normal_text holds content that this parser will not consume further; calls holds the parsed tool calls.
"""
action = json.loads(text)
return StreamingParseResult(calls=self.parse_base_json(action, tools))
def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
"""
Check if buffer ends with a partial bot_token.
Return the length of the partial bot_token.
For some formats, the bot_token is not a single token in the model's vocabulary,
such as `[TOOL_CALLS] [` in Mistral.
"""
for i in range(1, min(len(buffer) + 1, len(bot_token))):
if bot_token.startswith(buffer[-i:]):
return i
return 0
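For instance, with Mistral's begin token (mirroring the unit tests later in this commit):

```
parser._ends_with_partial_token("Some text [TOOL", "[TOOL_CALLS] ")  # 5 -> keep "[TOOL" buffered
parser._ends_with_partial_token("Some text", "[TOOL_CALLS] ")        # 0 -> safe to flush the buffer
```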
def parse_streaming_increment(self, new_text: str,
tools: List[Tool]) -> StreamingParseResult:
"""
Streaming incremental parsing with tool validation.
This base implementation works best with formats where:
1. bot_token is followed immediately by JSON (e.g., bot_token + JSON_array)
2. JSON can be parsed incrementally using partial_json_loads
3. Multiple tool calls are separated by "; " or ", "
Examples of incompatible formats (need custom implementation, may reuse some logic from this class):
- Each tool call is wrapped in a separate block: See Qwen25Detector
- Multiple separate blocks: [TOOL_CALLS] [...] \n [TOOL_CALLS] [...]
- Tool call is Pythonic style
For incompatible formats, detectors should override this method with custom logic.
"""
# Append new text to buffer
self._buffer += new_text
current_text = self._buffer
# current_text contains a tool call if it starts a new tool call sequence, or if
# there is a previous tool call and it starts a new call right after the separator
if not (self.has_tool_call(current_text) or
(self.current_tool_id > 0
and current_text.startswith(self.tool_call_separator))):
# Only clear buffer if we're sure no tool call is starting
if not self._ends_with_partial_token(self._buffer, self.bot_token):
normal_text = self._buffer
self._buffer = ""
if self.eot_token in normal_text:
normal_text = normal_text.replace(self.eot_token, "")
return StreamingParseResult(normal_text=normal_text)
else:
# Might be partial bot_token, keep buffering
return StreamingParseResult()
# Build tool indices if not already built
if not hasattr(self, "_tool_indices"):
self._tool_indices = self._get_tool_indices(tools)
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
try:
try:
tool_call_pos = current_text.find(self.bot_token)
if tool_call_pos != -1:
start_idx = tool_call_pos + len(self.bot_token)
elif self.current_tool_id > 0 and current_text.startswith(
self.tool_call_separator):
start_idx = len(self.tool_call_separator)
else:
start_idx = 0
if start_idx >= len(current_text):
return StreamingParseResult()
(obj, end_idx) = partial_json_loads(current_text[start_idx:],
flags)
is_current_complete = is_complete_json(
current_text[start_idx:start_idx + end_idx])
# Validate tool name if present
if "name" in obj and obj["name"] not in self._tool_indices:
# Invalid tool name - reset state
self._buffer = ""
self.current_tool_id = -1
self.current_tool_name_sent = False
if self.streamed_args_for_tool:
self.streamed_args_for_tool.pop()
return StreamingParseResult()
# Handle parameters/arguments consistency
# NOTE: we assume here that obj is always a partial parse of a single tool call
if "parameters" in obj:
assert ("arguments" not in obj
), "model generated both parameters and arguments"
obj["arguments"] = obj["parameters"]
current_tool_call = obj
except MalformedJSON:
return StreamingParseResult()
if not current_tool_call:
return StreamingParseResult()
# Case 1: Handle tool name streaming
# This happens when we encounter a tool but haven't sent its name yet
if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name and function_name in self._tool_indices:
# If this is a new tool (current_tool_id was -1), initialize it
if self.current_tool_id == -1:
self.current_tool_id = 0
self.streamed_args_for_tool.append("")
# If this is a subsequent tool, ensure streamed_args_for_tool is large enough
elif self.current_tool_id >= len(
self.streamed_args_for_tool):
while len(self.streamed_args_for_tool
) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
# Send the tool name with empty parameters
res = StreamingParseResult(calls=[
ToolCallItem(
tool_index=self.current_tool_id,
name=function_name,
parameters="",
)
], )
self.current_tool_name_sent = True
else:
res = StreamingParseResult()
# Case 2: Handle streaming arguments
# This happens when we've already sent the tool name and now need to stream arguments incrementally
else:
cur_arguments = current_tool_call.get("arguments")
res = StreamingParseResult()
if cur_arguments:
# Calculate how much of the arguments we've already streamed
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments)
prev_arguments = None
if self.current_tool_id < len(self.prev_tool_call_arr):
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
argument_diff = None
# If the current tool's JSON is complete, send all remaining arguments
if is_current_complete:
argument_diff = cur_args_json[sent:]
completing_tool_id = (
self.current_tool_id
) # Save the ID of the tool that's completing
# Only remove the processed portion, keep unprocessed content
self._buffer = current_text[start_idx + end_idx:]
if self.current_tool_id < len(self.prev_tool_call_arr):
self.prev_tool_call_arr[
self.current_tool_id].clear()
self.current_tool_name_sent = False
self.streamed_args_for_tool[self.current_tool_id] = ""
self.current_tool_id += 1
# If the tool is still being parsed, send incremental changes
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(prev_args_json,
cur_args_json)
argument_diff = prefix[sent:]
# Send the argument diff if there's something new
if argument_diff is not None:
# Use the correct tool_index: completing_tool_id for completed tools, current_tool_id for ongoing
tool_index_to_use = (completing_tool_id
if is_current_complete else
self.current_tool_id)
res = StreamingParseResult(calls=[
ToolCallItem(
tool_index=tool_index_to_use,
parameters=argument_diff,
)
], )
if not is_current_complete:
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
# Update prev_tool_call_arr with current state
if self.current_tool_id >= 0:
# Ensure prev_tool_call_arr is large enough
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
self.prev_tool_call_arr[
self.current_tool_id] = current_tool_call
return res
except Exception as e:
logger.error(f"Error in parse_streaming_increment: {e}")
return StreamingParseResult()
@abstractmethod
def has_tool_call(self, text: str) -> bool:
"""
Check if the given text contains function call markers specific to this format.
"""
raise NotImplementedError()
def supports_structural_tag(self) -> bool:
"""Return True if this detector supports structural tag format."""
return True
@abstractmethod
def structure_info(self) -> _GetInfoFunc:
"""
Return a function that creates StructureInfo for constrained generation.
The returned function takes a tool name and returns a StructureInfo object
containing the begin/end patterns and trigger tokens needed for constrained
generation of function calls in this format.
Returns:
A function that takes a tool name (str) and returns StructureInfo
"""
raise NotImplementedError()
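To make the streaming contract above concrete, here is a minimal sketch (assumed usage, not part of the patch) driving the base-class incremental parser through a toy subclass whose format is plain `bot_token + JSON`:

```
from tensorrt_llm.serve.openai_protocol import (ChatCompletionToolsParam,
                                                FunctionDefinition)
from tensorrt_llm.serve.tool_parser.base_tool_parser import BaseToolParser


class ToyToolParser(BaseToolParser):
    """Hypothetical format: '[CALL] {"name": ..., "arguments": ...}'."""

    def __init__(self):
        super().__init__()
        self.bot_token = "[CALL] "

    def has_tool_call(self, text: str) -> bool:
        return self.bot_token in text

    def detect_and_parse(self, text, tools):
        raise NotImplementedError

    def structure_info(self):
        raise NotImplementedError


tools = [ChatCompletionToolsParam(
    type="function", function=FunctionDefinition(name="get_weather"))]
parser = ToyToolParser()
for chunk in ['[CALL] {"name":"get_weather"',
              ',"arguments":{"location":"Paris"}}']:
    res = parser.parse_streaming_increment(chunk, tools)
    for call in res.calls:
        print(call.tool_index, call.name, repr(call.parameters))
# Expected: first a name-only item with empty parameters,
# then one carrying the completed argument JSON.
```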

View File

@@ -0,0 +1,35 @@
# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/qwen25_detector.py
from dataclasses import dataclass
from typing import Callable, List, Optional
from pydantic import BaseModel
class ToolCallItem(BaseModel):
"""Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
tool_index: int
name: Optional[str] = None
parameters: str # JSON string
class StreamingParseResult(BaseModel):
"""Result of streaming incremental parsing."""
normal_text: str = ""
calls: List[ToolCallItem] = []
@dataclass
class StructureInfo:
begin: str
end: str
trigger: str
"""
Type alias for a helper function.
It takes a tool name string and returns a StructureInfo object,
which can be used to construct a structural_tag object.
"""
_GetInfoFunc = Callable[[str], StructureInfo]
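A small illustration of how these types compose (a sketch with illustrative values, patterned on the Qwen3 parser below):

```
get_info: _GetInfoFunc = lambda name: StructureInfo(
    begin='<tool_call>\n{"name":"' + name + '", "arguments":',
    end="}\n</tool_call>",
    trigger="<tool_call>",
)
info = get_info("get_weather")  # StructureInfo for constrained generation
result = StreamingParseResult(
    normal_text="",
    calls=[ToolCallItem(tool_index=0, name="get_weather", parameters="{}")],
)
```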

View File

@@ -0,0 +1,114 @@
# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/qwen25_detector.py
import json
import re
from typing import List
from tensorrt_llm.logger import logger
from ..openai_protocol import ChatCompletionToolsParam as Tool
from .base_tool_parser import BaseToolParser
from .core_types import StreamingParseResult, StructureInfo, _GetInfoFunc
class Qwen3ToolParser(BaseToolParser):
"""
Detector for Qwen 2.5 and Qwen 3 model function call format.
Format Structure:
```
<tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
```
Key Components:
- Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
- Function Call Object: JSON object with "name" and "arguments" fields
Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default
"""
def __init__(self):
"""
Initializes the detector with necessary state variables.
"""
super().__init__()
self.bot_token = "<tool_call>\n" # nosec B105
self.eot_token = "\n</tool_call>" # nosec B105
self.tool_call_separator = "\n"
self._normal_text_buffer = "" # Buffer for handling partial end tokens
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Qwen 3 format tool call."""
return self.bot_token in text
def detect_and_parse(self, text: str,
tools: List[Tool]) -> StreamingParseResult:
"""
One-time parsing: Detects and parses tool calls in the provided text.
:param text: The complete text to parse.
:param tools: List of available tools.
:return: StreamingParseResult containing the leading normal text and the parsed tool calls.
"""
idx = text.find(self.bot_token)
normal_text = text[:idx].strip() if idx != -1 else text
if self.bot_token not in text:
return StreamingParseResult(normal_text=normal_text, calls=[])
# Find all <tool_call>\n...\n</tool_call> blocks
pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
match_result_list = re.findall(pattern, text, re.DOTALL)
calls = []
for match_result in match_result_list:
try:
parsed_call = json.loads(match_result.strip())
calls.extend(self.parse_base_json(parsed_call, tools))
except json.JSONDecodeError as e:
logger.warning(
f"Failed to parse JSON part: {match_result}, JSON parse error: {str(e)}"
)
continue
return StreamingParseResult(normal_text=normal_text, calls=calls)
def parse_streaming_increment(self, new_text: str,
tools: List[Tool]) -> StreamingParseResult:
"""
Streaming incremental parsing for Qwen 3 tool calls.
Uses base class implementation with buffering to handle partial end tokens.
"""
result = super().parse_streaming_increment(new_text, tools)
# Handle partial end tokens that are streamed character by character
if result.normal_text:
self._normal_text_buffer += result.normal_text
# Check if buffer contains complete end token (without leading newline)
end_token_without_newline = self.eot_token[1:] # "</tool_call>"
if end_token_without_newline in self._normal_text_buffer:
cleaned_text = self._normal_text_buffer.replace(
end_token_without_newline, "")
self._normal_text_buffer = ""
result.normal_text = cleaned_text
else:
# Check if buffer might contain partial end token at the end
partial_match_len = self._ends_with_partial_token(
self._normal_text_buffer, end_token_without_newline)
if partial_match_len:
# Keep potential partial match in buffer, return the rest
result.normal_text = self._normal_text_buffer[:
-partial_match_len]
self._normal_text_buffer = self._normal_text_buffer[
-partial_match_len:]
else:
# No partial match, return all buffered text
result.normal_text = self._normal_text_buffer
self._normal_text_buffer = ""
return result
def structure_info(self) -> _GetInfoFunc:
return lambda name: StructureInfo(
begin='<tool_call>\n{"name":"' + name + '", "arguments":',
end="}\n</tool_call>",
trigger="<tool_call>",
)
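End to end, the one-shot path then behaves like this (a sketch; `tools` must contain a `get_weather` function definition):

```
parser = Qwen3ToolParser()
text = ('Sure, checking.\n'
        '<tool_call>\n{"name":"get_weather","arguments":{"location":"Paris"}}\n</tool_call>')
result = parser.detect_and_parse(text, tools)
result.normal_text          # 'Sure, checking.'
result.calls[0].name        # 'get_weather'
result.calls[0].parameters  # '{"location": "Paris"}'
```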

View File

@@ -0,0 +1,21 @@
from typing import Type
from .base_tool_parser import BaseToolParser
from .qwen3_tool_parser import Qwen3ToolParser
class ToolParserFactory:
parsers: dict[str, Type[BaseToolParser]] = {
"qwen3": Qwen3ToolParser,
}
@staticmethod
def create_tool_parser(tool_parser: str) -> BaseToolParser:
try:
tool_parser_class = ToolParserFactory.parsers[tool_parser.lower()]
return tool_parser_class()
except KeyError as e:
raise ValueError(
f"Invalid tool_parser: {tool_parser}\n"
f"Supported parsers: {list(ToolParserFactory.parsers.keys())}"
) from e
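Typical factory usage (parser names come from the registry above):

```
from tensorrt_llm.serve.tool_parser import ToolParserFactory

parser = ToolParserFactory.create_tool_parser("qwen3")  # lookup is lower-cased
ToolParserFactory.create_tool_parser("nope")            # ValueError listing supported parsers
```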

View File

@@ -0,0 +1,56 @@
# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/qwen25_detector.py
import json
from json import JSONDecodeError, JSONDecoder
from json.decoder import WHITESPACE
from typing import Any
import partial_json_parser
from partial_json_parser.core.options import Allow
def find_common_prefix(s1: str, s2: str) -> str:
prefix = ""
min_length = min(len(s1), len(s2))
for i in range(0, min_length):
if s1[i] == s2[i]:
prefix += s1[i]
else:
break
return prefix
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
"""
Parse incomplete or partial JSON strings commonly encountered during streaming.
Args:
input_str (str): The potentially incomplete JSON string to parse.
flags (Allow): Bitwise flags controlling what types of partial data are allowed.
Common flags include:
- Allow.STR: Allow partial strings (e.g., '"hello wo' -> 'hello wo')
- Allow.OBJ: Allow partial objects (e.g., '{"key":' -> {'key': None})
- Allow.ARR: Allow partial arrays (e.g., '[1, 2,' -> [1, 2])
- Allow.ALL: Allow all types of partial data
Returns:
Tuple[Any, int]: A tuple containing:
- parsed_object: The Python object parsed from the JSON
- consumed_length: Number of characters consumed from input_str
"""
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
except (JSONDecodeError, IndexError) as e:
msg = getattr(e, "msg", str(e))
if "Extra data" in msg or "pop from empty list" in msg:
start = WHITESPACE.match(input_str, 0).end()
obj, end = JSONDecoder().raw_decode(input_str, start)
return obj, end
raise
def is_complete_json(input_str: str) -> bool:
try:
json.loads(input_str)
return True
except JSONDecodeError:
return False
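How these helpers behave on typical streaming fragments (a sketch; the partial-string result assumes Allow.STR is permitted):

```
from partial_json_parser.core.options import Allow

find_common_prefix('{"a": 1, "b"', '{"a": 1, "c"')  # '{"a": 1, "'
obj, end = partial_json_loads('{"location": "Par', Allow.ALL)
# obj == {'location': 'Par'}, end == len(input) for a partial parse
is_complete_json('{"location": "Paris"}')           # True
is_complete_json('{"location": "Par')               # False
```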

View File

@@ -1613,6 +1613,13 @@ def test_openai_reasoning(llm_root, llm_venv, backend: str):
])
def test_openai_tool_call(llm_root, llm_venv):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd(
["-m", "pytest",
str(test_root / "_test_openai_tool_call.py")])
@pytest.mark.parametrize("sampler", ["torch_sampler", "trtllm_sampler"])
def test_openai_completions_with_logit_bias(llm_root, llm_venv, sampler: str):
test_root = unittest_path() / "llmapi" / "apps"

View File

@@ -57,6 +57,7 @@ l0_a10:
- test_e2e.py::test_trtllm_serve_top_logprobs[pytorch]
- test_e2e.py::test_openai_misc_example[pytorch]
- test_e2e.py::test_openai_reasoning[pytorch]
- test_e2e.py::test_openai_tool_call
- test_e2e.py::test_openai_completions_example[pytorch]
- test_e2e.py::test_openai_chat_example[pytorch] TIMEOUT (90)
- test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
@@ -71,6 +72,8 @@ l0_a10:
- unittest/llmapi/test_additional_model_outputs.py
# executor
- unittest/executor/test_rpc.py
# trtllm-serve CPU-only
- unittest/llmapi/apps/test_tool_parsers.py
- condition:
ranges:
system_gpu_count:

View File

@@ -1,3 +1,3 @@
This directory contains the end-to-end tests for the LLM API applications in `examples/apps`.
This directory contains the end-to-end tests for `trtllm-serve`.
These tests are triggered in the `test_e2e.py`.

View File

@@ -0,0 +1,121 @@
import json
import os
import sys
import openai
import pytest
from .openai_server import RemoteOpenAIServer
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from test_llm import get_model_path
TOOLS = [
{
"type": "function",
"function": {
"name": "get_current_temperature",
"description": "Get current temperature at a location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type":
"string",
"description":
'The location to get the temperature for, in the format "City, State, Country".',
},
"unit": {
"type":
"string",
"enum": ["celsius", "fahrenheit"],
"description":
'The unit to return the temperature in. Defaults to "celsius".',
},
},
"required": ["location"],
},
},
},
]
def get_current_temperature(location: str, unit: str = "celsius") -> dict:
return {"temperature": 20 if unit == "celsius" else 68}
@pytest.fixture(scope="module", ids=["Qwen3-0.6B"])
def model_name() -> str:
return "Qwen3/Qwen3-0.6B"
@pytest.fixture(scope="module")
def server(model_name: str):
model_path = get_model_path(model_name)
args = ["--tool_parser", "qwen3"]
with RemoteOpenAIServer(model_path, cli_args=args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def client(server: RemoteOpenAIServer):
return server.get_async_client()
@pytest.mark.asyncio(loop_scope="module")
async def test_tool_parser(client: openai.AsyncOpenAI, model_name: str):
response = await client.chat.completions.create(
model=model_name,
messages=[{
"role": "user",
"content": "What's the temperature in San Francisco now?"
}],
tools=TOOLS)
assert response.choices[0].finish_reason == "tool_calls"
message = response.choices[0].message
assert message.content is not None
assert message.tool_calls is not None
assert len(message.tool_calls) == 1
tool_call = message.tool_calls[0]
assert tool_call.function.name == "get_current_temperature"
args = json.loads(tool_call.function.arguments)
get_current_temperature(**args)
@pytest.mark.asyncio(loop_scope="module")
async def test_tool_parser_streaming(client: openai.AsyncOpenAI,
model_name: str):
response = await client.chat.completions.create(
model=model_name,
messages=[{
"role": "user",
"content": "What's the temperature in San Francisco now?"
}],
tools=TOOLS,
stream=True)
tool_id = None
tool_name = None
parameters = ""
finish_reason = None
async for chunk in response:
if chunk.choices[0].delta.tool_calls:
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.id:
if tool_id is not None:
raise RuntimeError("tool_id already exists")
tool_id = tool_call.id
if tool_call.function.name:
if tool_name is not None:
raise RuntimeError("tool_name already exists")
tool_name = tool_call.function.name
if tool_call.function.arguments:
parameters += tool_call.function.arguments
if chunk.choices[0].finish_reason:
finish_reason = chunk.choices[0].finish_reason
assert tool_id is not None
assert tool_name == "get_current_temperature"
assert finish_reason == "tool_calls"
assert parameters
args = json.loads(parameters)
get_current_temperature(**args)

View File

@@ -0,0 +1,597 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import pytest
from tensorrt_llm.serve.openai_protocol import (ChatCompletionToolsParam,
FunctionDefinition)
from tensorrt_llm.serve.tool_parser.base_tool_parser import BaseToolParser
from tensorrt_llm.serve.tool_parser.core_types import StructureInfo
from tensorrt_llm.serve.tool_parser.qwen3_tool_parser import Qwen3ToolParser
# Test fixtures for common tools
@pytest.fixture
def sample_tools():
"""Sample tools for testing."""
return [
ChatCompletionToolsParam(
type="function",
function=FunctionDefinition(name="get_weather",
description="Get the current weather",
parameters={
"type": "object",
"properties": {
"location": {
"type":
"string",
"description":
"The city and state"
},
"unit": {
"type":
"string",
"enum":
["celsius", "fahrenheit"]
}
},
"required": ["location"]
})),
ChatCompletionToolsParam(
type="function",
function=FunctionDefinition(name="search_web",
description="Search the web",
parameters={
"type": "object",
"properties": {
"query": {
"type":
"string",
"description":
"The search query"
}
},
"required": ["query"]
}))
]
# Concrete implementation of BaseToolParser for testing
class ConcreteToolParser(BaseToolParser):
"""Concrete implementation of BaseToolParser for testing abstract methods."""
def __init__(self):
super().__init__()
self.bot_token = "[TOOL_CALLS] "
self.eot_token = "[/TOOL_CALLS]"
def has_tool_call(self, text: str) -> bool:
return self.bot_token in text
def detect_and_parse(self, text: str, tools):
# Placeholder to avoid NotImplementedError
pass
def structure_info(self):
return lambda name: StructureInfo(
begin=f'[TOOL_CALLS] {{"name":"{name}", "arguments":',
end="}[/TOOL_CALLS]",
trigger="[TOOL_CALLS]")
# ============================================================================
# BaseToolParser Tests
# ============================================================================
class TestBaseToolParser:
"""Test suite for BaseToolParser class."""
def test_initialization(self):
"""Test that BaseToolParser initializes correctly."""
parser = ConcreteToolParser()
assert parser._buffer == ""
assert parser.prev_tool_call_arr == []
assert parser.current_tool_id == -1
assert parser.current_tool_name_sent is False
assert parser.streamed_args_for_tool == []
def test_get_tool_indices(self, sample_tools):
"""Test _get_tool_indices correctly maps tool names to indices."""
parser = ConcreteToolParser()
indices = parser._get_tool_indices(sample_tools)
assert len(indices) == 2
assert indices["get_weather"] == 0
assert indices["search_web"] == 1
def test_get_tool_indices_empty(self):
"""Test _get_tool_indices with empty tools list."""
parser = ConcreteToolParser()
indices = parser._get_tool_indices([])
assert indices == {}
def test_parse_base_json_single_tool(self, sample_tools):
"""Test parse_base_json with a single tool call."""
parser = ConcreteToolParser()
action = {
"name": "get_weather",
"parameters": {
"location": "San Francisco"
}
}
results = parser.parse_base_json(action, sample_tools)
assert len(results) == 1
assert results[0].name == "get_weather"
assert json.loads(results[0].parameters) == {
"location": "San Francisco"
}
def test_parse_base_json_with_arguments_key(self, sample_tools):
"""Test parse_base_json handles 'arguments' key instead of 'parameters'."""
parser = ConcreteToolParser()
action = {"name": "search_web", "arguments": {"query": "TensorRT"}}
results = parser.parse_base_json(action, sample_tools)
assert len(results) == 1
assert results[0].name == "search_web"
assert json.loads(results[0].parameters) == {"query": "TensorRT"}
def test_parse_base_json_multiple_tools(self, sample_tools):
"""Test parse_base_json with multiple tool calls."""
parser = ConcreteToolParser()
actions = [{
"name": "get_weather",
"parameters": {
"location": "Boston"
}
}, {
"name": "search_web",
"arguments": {
"query": "Python"
}
}]
results = parser.parse_base_json(actions, sample_tools)
assert len(results) == 2
assert results[0].name == "get_weather"
assert results[1].name == "search_web"
def test_parse_base_json_undefined_function(self, sample_tools):
"""Test parse_base_json handles undefined function names gracefully."""
parser = ConcreteToolParser()
action = {"name": "undefined_function", "parameters": {}}
results = parser.parse_base_json(action, sample_tools)
# Should return empty list and log warning
assert len(results) == 0
def test_parse_base_json_missing_parameters(self, sample_tools):
"""Test parse_base_json handles missing parameters."""
parser = ConcreteToolParser()
action = {"name": "get_weather"}
results = parser.parse_base_json(action, sample_tools)
assert len(results) == 1
assert json.loads(results[0].parameters) == {}
def test_ends_with_partial_token(self):
"""Test _ends_with_partial_token detection."""
parser = ConcreteToolParser()
# Partial token at end (bot_token starts with the suffix)
assert parser._ends_with_partial_token("Some text [TOOL",
"[TOOL_CALLS] ") == 5
assert parser._ends_with_partial_token("Some text [",
"[TOOL_CALLS] ") == 1
assert parser._ends_with_partial_token("Some text [TOOL_CALLS",
"[TOOL_CALLS] ") == 11
# No partial token
assert parser._ends_with_partial_token("Some text",
"[TOOL_CALLS] ") == 0
assert parser._ends_with_partial_token("Some text [XYZ",
"[TOOL_CALLS] ") == 0
# Complete token at end (entire buffer is bot_token prefix but not complete match)
# When buffer equals bot_token, it returns 0 because it's not a partial anymore
assert parser._ends_with_partial_token("text [TOOL_CALLS] ",
"[TOOL_CALLS] ") == 0
def test_parse_streaming_increment_no_tool_call(self, sample_tools):
"""Test streaming parser returns normal text when no tool call present."""
parser = ConcreteToolParser()
result = parser.parse_streaming_increment("Hello, world!", sample_tools)
assert result.normal_text == "Hello, world!"
assert len(result.calls) == 0
def test_parse_streaming_increment_partial_bot_token(self, sample_tools):
"""Test streaming parser buffers partial bot token."""
parser = ConcreteToolParser()
# Send partial bot token
result = parser.parse_streaming_increment("[TOOL", sample_tools)
# Should buffer and return nothing
assert result.normal_text == ""
assert len(result.calls) == 0
assert parser._buffer == "[TOOL"
def test_parse_streaming_increment_tool_name(self, sample_tools):
"""Test streaming parser handles tool name streaming."""
parser = ConcreteToolParser()
# Send bot token with partial JSON containing name
result = parser.parse_streaming_increment(
'[TOOL_CALLS] {"name":"get_weather"', sample_tools)
# Should send tool name with empty parameters
assert len(result.calls) == 1
assert result.calls[0].name == "get_weather"
assert result.calls[0].parameters == ""
assert result.calls[0].tool_index == 0
assert parser.current_tool_name_sent is True
def test_parse_streaming_increment_tool_arguments(self, sample_tools):
"""Test streaming parser handles incremental argument streaming."""
parser = ConcreteToolParser()
# First send tool name
result1 = parser.parse_streaming_increment(
'[TOOL_CALLS] {"name":"get_weather"', sample_tools)
# Should send tool name
assert len(result1.calls) == 1
assert result1.calls[0].name == "get_weather"
# Then send complete arguments (parser needs complete JSON to parse incrementally)
result2 = parser.parse_streaming_increment(
',"arguments":{"location":"San Francisco"}}', sample_tools)
# Should stream arguments or complete the tool call
# The base implementation uses partial JSON parsing, so it may return results
assert result2 is not None # Just verify it doesn't crash
def test_parse_streaming_increment_complete_tool(self, sample_tools):
"""Test streaming parser handles complete tool call."""
parser = ConcreteToolParser()
# Send complete tool call in one chunk
result = parser.parse_streaming_increment(
'[TOOL_CALLS] {"name":"get_weather","arguments":{"location":"Boston"}}',
sample_tools)
# Should have sent tool name (first call)
assert len(result.calls) == 1
assert result.calls[0].name == "get_weather"
def test_parse_streaming_increment_invalid_tool_name(self, sample_tools):
"""Test streaming parser handles invalid tool name."""
parser = ConcreteToolParser()
# Send invalid tool name
result = parser.parse_streaming_increment(
'[TOOL_CALLS] {"name":"invalid_tool"', sample_tools)
# Should reset state
assert len(result.calls) == 0
assert parser._buffer == ""
assert parser.current_tool_id == -1
def test_supports_structural_tag(self):
"""Test supports_structural_tag returns True."""
parser = ConcreteToolParser()
assert parser.supports_structural_tag() is True
def test_structure_info(self):
"""Test structure_info returns proper function."""
parser = ConcreteToolParser()
func = parser.structure_info()
info = func("test_function")
assert isinstance(info, StructureInfo)
assert "test_function" in info.begin
assert info.trigger == "[TOOL_CALLS]"
# ============================================================================
# Qwen3ToolParser Tests
# ============================================================================
class TestQwen3ToolParser:
"""Test suite for Qwen3ToolParser class."""
def test_initialization(self):
"""Test that Qwen3ToolParser initializes correctly."""
parser = Qwen3ToolParser()
assert parser.bot_token == "<tool_call>\n"
assert parser.eot_token == "\n</tool_call>"
assert parser.tool_call_separator == "\n"
assert parser._normal_text_buffer == ""
def test_has_tool_call_true(self):
"""Test has_tool_call returns True when tool call is present."""
parser = Qwen3ToolParser()
text = 'Some text <tool_call>\n{"name":"get_weather"}\n</tool_call>'
assert parser.has_tool_call(text) is True
def test_has_tool_call_false(self):
"""Test has_tool_call returns False when no tool call present."""
parser = Qwen3ToolParser()
text = "Just some regular text without tool calls"
assert parser.has_tool_call(text) is False
def test_detect_and_parse_no_tool_call(self, sample_tools):
"""Test detect_and_parse with text containing no tool calls."""
parser = Qwen3ToolParser()
text = "This is just a regular response."
result = parser.detect_and_parse(text, sample_tools)
assert result.normal_text == "This is just a regular response."
assert len(result.calls) == 0
def test_detect_and_parse_single_tool(self, sample_tools):
"""Test detect_and_parse with a single tool call."""
parser = Qwen3ToolParser()
text = 'Normal text\n<tool_call>\n{"name":"get_weather","arguments":{"location":"NYC"}}\n</tool_call>'
result = parser.detect_and_parse(text, sample_tools)
assert result.normal_text == "Normal text"
assert len(result.calls) == 1
assert result.calls[0].name == "get_weather"
assert json.loads(result.calls[0].parameters) == {"location": "NYC"}
def test_detect_and_parse_multiple_tools(self, sample_tools):
"""Test detect_and_parse with multiple tool calls."""
parser = Qwen3ToolParser()
text = (
'<tool_call>\n{"name":"get_weather","arguments":{"location":"LA"}}\n</tool_call>\n'
'<tool_call>\n{"name":"search_web","arguments":{"query":"AI"}}\n</tool_call>'
)
result = parser.detect_and_parse(text, sample_tools)
assert len(result.calls) == 2
assert result.calls[0].name == "get_weather"
assert result.calls[1].name == "search_web"
def test_detect_and_parse_malformed_json(self, sample_tools):
"""Test detect_and_parse handles malformed JSON gracefully."""
parser = Qwen3ToolParser()
text = '<tool_call>\n{"name":"get_weather","arguments":MALFORMED}\n</tool_call>'
result = parser.detect_and_parse(text, sample_tools)
# Should return empty calls due to JSON parsing error
assert len(result.calls) == 0
def test_detect_and_parse_with_parameters_key(self, sample_tools):
"""Test detect_and_parse handles 'parameters' key."""
parser = Qwen3ToolParser()
text = '<tool_call>\n{"name":"search_web","parameters":{"query":"test"}}\n</tool_call>'
result = parser.detect_and_parse(text, sample_tools)
assert len(result.calls) == 1
assert result.calls[0].name == "search_web"
assert json.loads(result.calls[0].parameters) == {"query": "test"}
def test_parse_streaming_increment_normal_text(self, sample_tools):
"""Test streaming parser handles normal text without tool calls."""
parser = Qwen3ToolParser()
result = parser.parse_streaming_increment("Hello, how can I help?",
sample_tools)
assert result.normal_text == "Hello, how can I help?"
assert len(result.calls) == 0
def test_parse_streaming_increment_partial_bot_token(self, sample_tools):
"""Test streaming parser buffers partial bot token."""
parser = Qwen3ToolParser()
# Send partial bot token
result = parser.parse_streaming_increment("<tool", sample_tools)
# Should buffer
assert result.normal_text == ""
assert len(result.calls) == 0
def test_parse_streaming_increment_complete_tool_call(self, sample_tools):
"""Test streaming parser with complete tool call in chunks."""
parser = Qwen3ToolParser()
# Send bot token
parser.parse_streaming_increment("<tool_call>\n", sample_tools)
# Send partial JSON with name
result = parser.parse_streaming_increment('{"name":"get_weather"',
sample_tools)
# Should send tool name
assert len(result.calls) == 1
assert result.calls[0].name == "get_weather"
assert result.calls[0].parameters == ""
# Send arguments
result = parser.parse_streaming_increment(
',"arguments":{"location":"SF"}}\n</tool_call>', sample_tools)
# Should stream arguments
assert len(result.calls) == 1
assert json.loads(result.calls[0].parameters) == {"location": "SF"}
def test_parse_streaming_increment_end_token_handling(self, sample_tools):
"""Test streaming parser handles end token correctly."""
parser = Qwen3ToolParser()
# Send complete tool call
parser.parse_streaming_increment(
'<tool_call>\n{"name":"get_weather","arguments":{"location":"NYC"}}\n</tool_call>',
sample_tools)
# The end token should be removed from normal text
# Check buffer state
assert parser._normal_text_buffer == ""
def test_parse_streaming_increment_multiple_tools_streaming(
self, sample_tools):
"""Test streaming parser handles multiple tool calls."""
parser = Qwen3ToolParser()
# First tool
parser.parse_streaming_increment('<tool_call>\n', sample_tools)
parser.parse_streaming_increment(
'{"name":"get_weather","arguments":{"location":"NYC"}}\n</tool_call>\n',
sample_tools)
# Second tool
parser.parse_streaming_increment('<tool_call>\n', sample_tools)
result = parser.parse_streaming_increment('{"name":"search_web"',
sample_tools)
# Should have started second tool
assert result.calls[0].name == "search_web"
assert result.calls[0].parameters == ""
assert result.calls[0].tool_index == 1
def test_structure_info_function(self):
"""Test structure_info returns correct lambda function."""
parser = Qwen3ToolParser()
func = parser.structure_info()
info = func("test_function")
assert isinstance(info, StructureInfo)
assert info.begin == '<tool_call>\n{"name":"test_function", "arguments":'
assert info.end == "}\n</tool_call>"
assert info.trigger == "<tool_call>"
def test_structure_info_different_names(self):
"""Test structure_info works with different function names."""
parser = Qwen3ToolParser()
func = parser.structure_info()
info1 = func("get_weather")
info2 = func("search_web")
assert "get_weather" in info1.begin
assert "search_web" in info2.begin
assert info1.end == info2.end == "}\n</tool_call>"
def test_qwen3_format_compliance(self, sample_tools):
"""Test that Qwen3ToolParser follows the documented format structure."""
parser = Qwen3ToolParser()
# Test the exact format from the docstring
text = '<tool_call>\n{"name":"get_weather", "arguments":{"location":"Tokyo"}}\n</tool_call>'
result = parser.detect_and_parse(text, sample_tools)
assert len(result.calls) == 1
assert result.calls[0].name == "get_weather"
assert json.loads(result.calls[0].parameters) == {"location": "Tokyo"}
def test_undefined_tool_in_qwen3_format(self, sample_tools):
"""Test Qwen3ToolParser handles undefined tool gracefully."""
parser = Qwen3ToolParser()
text = '<tool_call>\n{"name":"undefined_func","arguments":{}}\n</tool_call>'
result = parser.detect_and_parse(text, sample_tools)
# Should not return any calls for undefined function
assert len(result.calls) == 0
# ============================================================================
# Integration Tests
# ============================================================================
class TestToolParserIntegration:
"""Integration tests for tool parsers."""
def test_end_to_end_single_tool(self, sample_tools):
"""Test end-to-end parsing of a single tool call."""
parser = Qwen3ToolParser()
# Simulate streaming
chunks = [
"<tool_call>\n", '{"name":"get', '_weather"', ',"arguments":',
'{"location"', ':"Paris"}}\n', '</tool_call>'
]
results = []
for chunk in chunks:
result = parser.parse_streaming_increment(chunk, sample_tools)
if result.calls or result.normal_text:
results.append(result)
# Should have received tool name and arguments
assert any(r.calls for r in results)
def test_mixed_content_and_tool_calls(self, sample_tools):
"""Test parsing text that mixes normal content with tool calls."""
parser = Qwen3ToolParser()
text = (
'I will check the weather for you.\n'
'<tool_call>\n{"name":"get_weather","arguments":{"location":"London"}}\n</tool_call>\n'
'Let me search that for you.')
result = parser.detect_and_parse(text, sample_tools)
assert "I will check the weather for you." in result.normal_text
assert len(result.calls) == 1
assert result.calls[0].name == "get_weather"
def test_parser_state_reset(self, sample_tools):
"""Test that parser state can be used for multiple requests."""
parser = Qwen3ToolParser()
# First request
result1 = parser.detect_and_parse(
'<tool_call>\n{"name":"get_weather","arguments":{"location":"NYC"}}\n</tool_call>',
sample_tools)
# Reset internal state for new request
parser2 = Qwen3ToolParser()
# Second request
result2 = parser2.detect_and_parse(
'<tool_call>\n{"name":"search_web","arguments":{"query":"test"}}\n</tool_call>',
sample_tools)
assert result1.calls[0].name == "get_weather"
assert result2.calls[0].name == "search_web"
if __name__ == "__main__":
pytest.main([__file__, "-v"])