Mirror of https://github.com/NVIDIA/TensorRT-LLM.git

[TRTLLM-8214][feat] Support Qwen3 tool parser (#8216)
Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
Commit 2aade46d18, parent 741183917c
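For orientation before the diff itself, here is a minimal sketch of how the new flag is exercised end to end. The launch command and the "qwen3" parser choice come from this commit; the model path, port, and tool schema are illustrative assumptions (the tool mirrors the `get_current_temperature` schema in the new test file below):

```python
# Assumed launch command (model path and default port are assumptions):
#   trtllm-serve Qwen3/Qwen3-0.6B --tool_parser qwen3
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
tools = [{
    "type": "function",
    "function": {
        "name": "get_current_temperature",
        "description": "Get current temperature at a location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"}
            },
            "required": ["location"],
        },
    },
}]
response = client.chat.completions.create(
    model="Qwen3/Qwen3-0.6B",
    messages=[{"role": "user", "content": "What's the temperature in Paris?"}],
    tools=tools)
# When the model emits a <tool_call> block, the qwen3 parser converts it into a
# structured tool call and the finish reason becomes "tool_calls".
print(response.choices[0].finish_reason)
print(response.choices[0].message.tool_calls)
```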
@@ -78,3 +78,4 @@ nvidia-cutlass-dsl==4.2.1; python_version >= "3.10"
 numba-cuda>=0.19.0 # WAR for nvbugs/5501820
 plotly
 numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
+partial_json_parser
@@ -33,6 +33,7 @@ from tensorrt_llm.llmapi.mpi_session import find_free_port
 from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
 from tensorrt_llm.logger import logger, severity_map
 from tensorrt_llm.serve import OpenAIDisaggServer, OpenAIServer
+from tensorrt_llm.serve.tool_parser import ToolParserFactory

 # Global variable to store the Popen object of the child process
 _child_p_global: Optional[subprocess.Popen] = None
@@ -150,6 +151,7 @@ def launch_server(
     host: str,
     port: int,
     llm_args: dict,
+    tool_parser: Optional[str] = None,
     metadata_server_cfg: Optional[MetadataServerConfig] = None,
     server_role: Optional[ServerRole] = None,
     disagg_cluster_config: Optional[DisaggClusterConfig] = None,
@@ -173,6 +175,7 @@ def launch_server(

     server = OpenAIServer(llm=llm,
                           model=model,
+                          tool_parser=tool_parser,
                           server_role=server_role,
                           metadata_server_cfg=metadata_server_cfg,
                           disagg_cluster_config=disagg_cluster_config,
@@ -311,6 +314,12 @@ class ChoiceWithAlias(click.Choice):
     default=None,
     help="[Experimental] Specify the parser for reasoning models.",
 )
+@click.option(
+    "--tool_parser",
+    type=click.Choice(ToolParserFactory.parsers.keys()),
+    default=None,
+    help="[Experimental] Specify the parser for tool models.",
+)
 @click.option("--metadata_server_config_file",
               type=str,
               default=None,
@@ -352,7 +361,8 @@ def serve(
           gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float,
           num_postprocess_workers: int, trust_remote_code: bool,
           extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
-          metadata_server_config_file: Optional[str], server_role: Optional[str],
+          tool_parser: Optional[str], metadata_server_config_file: Optional[str],
+          server_role: Optional[str],
           fail_fast_on_attention_window_too_large: bool,
           otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
           disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str]):
@@ -423,8 +433,8 @@ def serve(

     multimodal_server_config = MultimodalServerConfig(
         media_io_kwargs=parsed_media_io_kwargs)
-    launch_server(host, port, llm_args, metadata_server_cfg, server_role,
-                  disagg_cluster_config, multimodal_server_config)
+    launch_server(host, port, llm_args, tool_parser, metadata_server_cfg,
+                  server_role, disagg_cluster_config, multimodal_server_config)


 @click.command("mm_embedding_serve")
@@ -1,3 +1,4 @@
+import uuid
 from functools import partial
 from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
                     Optional, Tuple, TypeAlias, TypedDict, Union, cast)
@@ -220,3 +221,11 @@ def check_multiple_response(n: int, backend: Optional[str]):
     if n > 1 and backend == "pytorch":
         raise ValueError(
             "Multiple response is not supported in PyTorch workflow")
+
+
+def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
+    if id_type == "kimi_k2":
+        return f"functions.{func_name}:{idx}"
+    else:
+        # by default return random
+        return f"chatcmpl-tool-{uuid.uuid4().hex}"
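For reference, the two ID styles the new helper produces (the `kimi_k2` branch appears to mirror Kimi-K2's `functions.<name>:<index>` convention; the hex digits shown are illustrative):

```python
make_tool_call_id("kimi_k2", func_name="get_weather", idx=0)
# -> 'functions.get_weather:0'
make_tool_call_id()
# -> 'chatcmpl-tool-<32 hex digits>', e.g. 'chatcmpl-tool-9f1c...'
```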
@@ -78,12 +78,14 @@ class OpenAIServer:
     def __init__(self,
                  llm: Union[LLM, MultimodalEncoder],
                  model: str,
+                 tool_parser: Optional[str],
                  server_role: Optional[ServerRole],
                  metadata_server_cfg: MetadataServerConfig,
                  disagg_cluster_config: Optional[DisaggClusterConfig] = None,
                  multimodal_server_config: Optional[MultimodalServerConfig] = None):
         self.llm = llm
         self.tokenizer = llm.tokenizer
+        self.tool_parser = tool_parser
         self.metadata_server = create_metadata_server(metadata_server_cfg)
         self.disagg_cluster_config = disagg_cluster_config
         self.multimodal_server_config = multimodal_server_config
@@ -532,6 +534,7 @@ class OpenAIServer:
             prompt["multi_modal_data"] = mm_data

         postproc_args.reasoning_parser = self.llm.args.reasoning_parser
+        postproc_args.tool_parser = self.tool_parser
         if conversation and conversation[-1].get(
                 "content") and conversation[-1].get("role") == get_role():
             postproc_args.last_message_content = conversation[-1]["content"]
@@ -10,6 +10,7 @@ from ..llmapi.reasoning_parser import (BaseReasoningParser,
                                        ReasoningParserFactory)
 from ..llmapi.tokenizer import TransformersTokenizer
 # yapf: disable
+from .chat_utils import make_tool_call_id
 from .harmony_adapter import (handle_non_streaming_response,
                               handle_streaming_response)
 from .openai_protocol import (ChatCompletionLogProbs,
@@ -23,9 +24,13 @@ from .openai_protocol import (ChatCompletionLogProbs,
                               CompletionRequest, CompletionResponse,
                               CompletionResponseChoice,
                               CompletionResponseStreamChoice,
-                              CompletionStreamResponse, DeltaMessage,
-                              FunctionCall, PromptTokensDetails, StreamOptions,
-                              ToolCall, UsageInfo, to_disaggregated_params)
+                              CompletionStreamResponse, DeltaFunctionCall,
+                              DeltaMessage, DeltaToolCall, FunctionCall,
+                              PromptTokensDetails, StreamOptions, ToolCall,
+                              UsageInfo, to_disaggregated_params)
+from .tool_parser.base_tool_parser import BaseToolParser
+from .tool_parser.core_types import ToolCallItem
+from .tool_parser.tool_parser_factory import ToolParserFactory

 # yapf: enable

@@ -33,8 +38,8 @@ from .openai_protocol import (ChatCompletionLogProbs,
 @dataclass(kw_only=True)
 class ChatPostprocArgs(PostprocArgs):
     echo: bool = False
-    role: str = None
-    model: str = None
+    role: str
+    model: str
     num_choices: int = 1
     tools: Optional[List[ChatCompletionToolsParam]] = None
     tool_choice: Optional[Union[Literal["none"],
@@ -44,8 +49,11 @@ class ChatPostprocArgs(PostprocArgs):
     stream_options: Optional[StreamOptions] = None
     last_message_content: Optional[str] = None
     reasoning_parser: Optional[str] = None
+    tool_parser: Optional[str] = None
     reasoning_parser_dict: dict[int, BaseReasoningParser] = field(
         default_factory=dict)
+    tool_parser_dict: dict[int, BaseToolParser] = field(default_factory=dict)
+    has_tool_call: dict[int, bool] = field(default_factory=dict)

     @classmethod
     def from_request(cls, request: ChatCompletionRequest):
@@ -116,6 +124,31 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
     return content, reasoning_content


+def apply_tool_parser(args: ChatPostprocArgs, output_index: int, text: str,
+                      streaming: bool) -> Tuple[str, List[ToolCallItem]]:
+    tool_parser = None
+    tools = args.tools
+    if args.tool_parser is not None and tools is not None:
+        if output_index not in args.tool_parser_dict:
+            args.tool_parser_dict[
+                output_index] = ToolParserFactory.create_tool_parser(
+                    args.tool_parser)
+        tool_parser = args.tool_parser_dict[output_index]
+
+    if tool_parser is not None and tools is not None:
+        if not streaming:
+            result = tool_parser.detect_and_parse(text, tools)
+        else:
+            result = tool_parser.parse_streaming_increment(text, tools)
+        normal_text, calls = result.normal_text, result.calls
+        if result.calls:
+            args.has_tool_call[output_index] = True
+    else:
+        normal_text, calls = text, []
+
+    return normal_text, calls
+
+
 @nvtx_range_debug("chat_stream_post_processor")
 def chat_stream_post_processor(rsp: GenerationResultBase,
                                args: ChatPostprocArgs) -> List[str]:
@@ -176,27 +209,63 @@ def chat_stream_post_processor(rsp: GenerationResultBase,
         if args.tool_choice and type(
                 args.tool_choice) is ChatCompletionNamedToolChoiceParam:
             delta_message = DeltaMessage(tool_calls=[
-                ToolCall(function=FunctionCall(
-                    name=args.tool_choice.function.name, arguments=delta_text))
-            ])
+                DeltaToolCall(
+                    function=DeltaFunctionCall(
+                        name=args.tool_choice.function.name,
+                        arguments=delta_text),
+                    index=i,
+                ),
+            ], )
         else:
-            delta_message = DeltaMessage(content=delta_text,
-                                         reasoning_content=reasoning_delta_text)
+            delta_text, calls = apply_tool_parser(args, i, delta_text, True)
+            tool_calls = []
+            for call_item in calls:
+                # Tool call ID should be generated only once per tool call
+                if call_item.name:
+                    # First chunk: include ID and function name
+                    tool_call_id = make_tool_call_id()
+                    function_name = call_item.name
+                else:
+                    # Subsequent chunks: null ID and name for argument deltas
+                    tool_call_id = None
+                    function_name = None
+
+                tool_calls.append(
+                    DeltaToolCall(
+                        id=tool_call_id,
+                        index=call_item.tool_index,
+                        function=DeltaFunctionCall(
+                            name=function_name,
+                            arguments=call_item.parameters,
+                        ),
+                    ))
+            if tool_calls or delta_text or reasoning_delta_text or output.finish_reason:
+                delta_message = DeltaMessage(
+                    content=delta_text,
+                    reasoning_content=reasoning_delta_text,
+                    tool_calls=tool_calls if tool_calls else None)
+            else:
+                continue

         choice = ChatCompletionResponseStreamChoice(
             index=i,
             delta=delta_message,
             finish_reason=None,
             avg_decoded_tokens_per_iter=getattr(rsp,
                                                 'avg_decoded_tokens_per_iter',
-                                                None))
+                                                None),
+            stop_reason=output.stop_reason,
+        )
         if args.return_logprobs:
             logprobs = output.logprobs_diff
             token_ids = output.token_ids_diff
             choice.logprobs = create_logprobs(token_ids, args.tokenizer,
                                               logprobs, args.top_logprobs)
         if output.finish_reason is not None:
-            choice.finish_reason = output.finish_reason
+            if output.finish_reason == "stop" and args.has_tool_call.get(
+                    i, False):
+                choice.finish_reason = "tool_calls"
+            else:
+                choice.finish_reason = output.finish_reason
             choice.stop_reason = output.stop_reason
             finish_reason_sent[i] = True
         chunk = ChatCompletionStreamResponse(choices=[choice], model=args.model)
@@ -247,21 +316,34 @@ def chat_response_post_processor(
                     name=args.tool_choice.function.name, arguments=text))
             ])
         else:
             if text is None:
                 text = ""
+            text, calls = apply_tool_parser(args, output.index, text, False)
+            tool_calls = [
+                ToolCall(function=FunctionCall(name=call.name or "",
+                                               arguments=call.parameters))
+                for call in calls
+            ]
             message = ChatMessage(role=role,
                                   content=text,
-                                  reasoning_content=reasoning_text)
+                                  reasoning_content=reasoning_text,
+                                  tool_calls=tool_calls)
         disaggregated_params = to_disaggregated_params(
             output.disaggregated_params)
         choice = ChatCompletionResponseChoice(
             index=output.index,
             message=message,
             finish_reason=output.finish_reason,
             stop_reason=output.stop_reason,
             disaggregated_params=disaggregated_params,
             avg_decoded_tokens_per_iter=getattr(rsp,
                                                 'avg_decoded_tokens_per_iter',
                                                 None),
         )
+        if output.finish_reason == "stop" and args.has_tool_call.get(
+                output.index, False):
+            choice.finish_reason = "tool_calls"
+        else:
+            choice.finish_reason = output.finish_reason
+
         if args.return_logprobs:
             choice.logprobs = create_logprobs(output.token_ids, args.tokenizer,
tensorrt_llm/serve/tool_parser/__init__.py (new file, 3 lines)

from .tool_parser_factory import ToolParserFactory

__all__ = ["ToolParserFactory"]
tensorrt_llm/serve/tool_parser/base_tool_parser.py (new file, 324 lines)

# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/base_format_detector.py
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List

from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow

from tensorrt_llm.logger import logger

from ..openai_protocol import ChatCompletionToolsParam as Tool
from .core_types import StreamingParseResult, ToolCallItem, _GetInfoFunc
from .utils import find_common_prefix, is_complete_json, partial_json_loads


class BaseToolParser(ABC):
    """Base class providing two sets of interfaces: one-time and streaming incremental."""

    def __init__(self):
        # Streaming state management
        # Buffer for accumulating incomplete patterns that arrive across multiple streaming chunks
        self._buffer = ""
        # Stores complete tool call info (name and arguments) for each tool being parsed.
        # Used by the serving layer for completion handling when streaming ends.
        # Format: [{"name": str, "arguments": dict}, ...]
        self.prev_tool_call_arr: List[Dict] = []
        # Index of currently streaming tool call. Starts at -1 (no active tool),
        # increments as each tool completes. Tracks which tool's arguments are streaming.
        self.current_tool_id: int = -1
        # Flag for whether the current tool's name has been sent to the client.
        # Tool names are sent first with empty parameters, then arguments stream incrementally.
        self.current_tool_name_sent: bool = False
        # Tracks raw JSON string content streamed to the client for each tool's arguments.
        # Critical for the serving layer to calculate remaining content when streaming ends.
        # Each index corresponds to a tool_id. Example: ['{"location": "San Francisco"', '{"temp": 72']
        self.streamed_args_for_tool: List[str] = []

        # Token configuration (override in subclasses)
        self.bot_token = ""  # nosec B105
        self.eot_token = ""  # nosec B105
        self.tool_call_separator = ", "

    def _get_tool_indices(self, tools: List[Tool]) -> Dict[str, int]:
        """
        Get a mapping of tool names to their indices in the tools list.

        This utility method creates a dictionary mapping function names to their
        indices in the tools list, which is commonly needed for tool validation
        and ToolCallItem creation.

        Args:
            tools: List of available tools

        Returns:
            Dictionary mapping tool names to their indices
        """
        return {
            tool.function.name: i
            for i, tool in enumerate(tools) if tool.function.name
        }

    def parse_base_json(self, action: Any,
                        tools: List[Tool]) -> List[ToolCallItem]:
        tool_indices = self._get_tool_indices(tools)
        if not isinstance(action, list):
            action = [action]

        results = []
        for act in action:
            name = act.get("name")
            if name and name in tool_indices:
                results.append(
                    ToolCallItem(
                        tool_index=
                        -1,  # Caller should update this based on the actual tools array called
                        name=name,
                        parameters=json.dumps(
                            act.get("parameters") or act.get("arguments", {}),
                            ensure_ascii=False,
                        ),
                    ))
            else:
                logger.warning(
                    f"Model attempted to call undefined function: {name}")

        return results

    @abstractmethod
    def detect_and_parse(self, text: str,
                         tools: List[Tool]) -> StreamingParseResult:
        """
        Parses the text in one go. Returns success=True if the format matches, otherwise False.
        Note that leftover_text here represents "content that this parser will not consume further".
        """
        action = json.loads(text)
        return StreamingParseResult(calls=self.parse_base_json(action, tools))

    def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
        """
        Check if buffer ends with a partial bot_token.
        Return the length of the partial bot_token.

        For some formats, the bot_token is not a single token in the model's
        vocabulary, such as `[TOOL_CALLS] [` in Mistral.
        """
        for i in range(1, min(len(buffer) + 1, len(bot_token))):
            if bot_token.startswith(buffer[-i:]):
                return i
        return 0

    def parse_streaming_increment(self, new_text: str,
                                  tools: List[Tool]) -> StreamingParseResult:
        """
        Streaming incremental parsing with tool validation.

        This base implementation works best with formats where:
        1. bot_token is followed immediately by JSON (e.g., bot_token + JSON_array)
        2. JSON can be parsed incrementally using partial_json_loads
        3. Multiple tool calls are separated by "; " or ", "

        Examples of incompatible formats (need custom implementation, may reuse some logic from this class):
        - Each tool call is wrapped in a separate block: See Qwen25Detector
        - Multiple separate blocks: [TOOL_CALLS] [...] \n [TOOL_CALLS] [...]
        - Tool call is Pythonic style

        For incompatible formats, detectors should override this method with custom logic.
        """
        # Append new text to buffer
        self._buffer += new_text
        current_text = self._buffer

        # The current_text has tool_call if it is the start of a new tool call sequence
        # or it is the start of a new tool call after a tool call separator, when there is a previous tool call
        if not (self.has_tool_call(current_text) or
                (self.current_tool_id > 0
                 and current_text.startswith(self.tool_call_separator))):
            # Only clear buffer if we're sure no tool call is starting
            if not self._ends_with_partial_token(self._buffer, self.bot_token):
                normal_text = self._buffer
                self._buffer = ""
                if self.eot_token in normal_text:
                    normal_text = normal_text.replace(self.eot_token, "")
                return StreamingParseResult(normal_text=normal_text)
            else:
                # Might be partial bot_token, keep buffering
                return StreamingParseResult()

        # Build tool indices if not already built
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = self._get_tool_indices(tools)

        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR

        try:
            try:
                tool_call_pos = current_text.find(self.bot_token)
                if tool_call_pos != -1:
                    start_idx = tool_call_pos + len(self.bot_token)
                elif self.current_tool_id > 0 and current_text.startswith(
                        self.tool_call_separator):
                    start_idx = len(self.tool_call_separator)
                else:
                    start_idx = 0

                if start_idx >= len(current_text):
                    return StreamingParseResult()

                (obj, end_idx) = partial_json_loads(current_text[start_idx:],
                                                    flags)

                is_current_complete = is_complete_json(
                    current_text[start_idx:start_idx + end_idx])

                # Validate tool name if present
                if "name" in obj and obj["name"] not in self._tool_indices:
                    # Invalid tool name - reset state
                    self._buffer = ""
                    self.current_tool_id = -1
                    self.current_tool_name_sent = False
                    if self.streamed_args_for_tool:
                        self.streamed_args_for_tool.pop()
                    return StreamingParseResult()

                # Handle parameters/arguments consistency
                # NOTE: we assume here that the obj is always partial of a single tool call
                if "parameters" in obj:
                    assert ("arguments" not in obj
                            ), "model generated both parameters and arguments"
                    obj["arguments"] = obj["parameters"]

                current_tool_call = obj

            except MalformedJSON:
                return StreamingParseResult()

            if not current_tool_call:
                return StreamingParseResult()

            # Case 1: Handle tool name streaming
            # This happens when we encounter a tool but haven't sent its name yet
            if not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")

                if function_name and function_name in self._tool_indices:
                    # If this is a new tool (current_tool_id was -1), initialize it
                    if self.current_tool_id == -1:
                        self.current_tool_id = 0
                        self.streamed_args_for_tool.append("")
                    # If this is a subsequent tool, ensure streamed_args_for_tool is large enough
                    elif self.current_tool_id >= len(
                            self.streamed_args_for_tool):
                        while len(self.streamed_args_for_tool
                                  ) <= self.current_tool_id:
                            self.streamed_args_for_tool.append("")

                    # Send the tool name with empty parameters
                    res = StreamingParseResult(calls=[
                        ToolCallItem(
                            tool_index=self.current_tool_id,
                            name=function_name,
                            parameters="",
                        )
                    ], )
                    self.current_tool_name_sent = True
                else:
                    res = StreamingParseResult()

            # Case 2: Handle streaming arguments
            # This happens when we've already sent the tool name and now need to stream arguments incrementally
            else:
                cur_arguments = current_tool_call.get("arguments")
                res = StreamingParseResult()

                if cur_arguments:
                    # Calculate how much of the arguments we've already streamed
                    sent = len(
                        self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = None
                    if self.current_tool_id < len(self.prev_tool_call_arr):
                        prev_arguments = self.prev_tool_call_arr[
                            self.current_tool_id].get("arguments")

                    argument_diff = None

                    # If the current tool's JSON is complete, send all remaining arguments
                    if is_current_complete:
                        argument_diff = cur_args_json[sent:]
                        completing_tool_id = (
                            self.current_tool_id
                        )  # Save the ID of the tool that's completing

                        # Only remove the processed portion, keep unprocessed content
                        self._buffer = current_text[start_idx + end_idx:]

                        if self.current_tool_id < len(self.prev_tool_call_arr):
                            self.prev_tool_call_arr[
                                self.current_tool_id].clear()
                        self.current_tool_name_sent = False
                        self.streamed_args_for_tool[self.current_tool_id] = ""
                        self.current_tool_id += 1

                    # If the tool is still being parsed, send incremental changes
                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            prefix = find_common_prefix(prev_args_json,
                                                        cur_args_json)
                            argument_diff = prefix[sent:]

                    # Send the argument diff if there's something new
                    if argument_diff is not None:
                        # Use the correct tool_index: completing_tool_id for completed tools, current_tool_id for ongoing
                        tool_index_to_use = (completing_tool_id
                                             if is_current_complete else
                                             self.current_tool_id)
                        res = StreamingParseResult(calls=[
                            ToolCallItem(
                                tool_index=tool_index_to_use,
                                parameters=argument_diff,
                            )
                        ], )
                        if not is_current_complete:
                            self.streamed_args_for_tool[
                                self.current_tool_id] += argument_diff

            # Update prev_tool_call_arr with current state
            if self.current_tool_id >= 0:
                # Ensure prev_tool_call_arr is large enough
                while len(self.prev_tool_call_arr) <= self.current_tool_id:
                    self.prev_tool_call_arr.append({})
                self.prev_tool_call_arr[
                    self.current_tool_id] = current_tool_call

            return res

        except Exception as e:
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult()

    @abstractmethod
    def has_tool_call(self, text: str) -> bool:
        """
        Check if the given text contains function call markers specific to this format.
        """
        raise NotImplementedError()

    def supports_structural_tag(self) -> bool:
        """Return True if this detector supports structural tag format."""
        return True

    @abstractmethod
    def structure_info(self) -> _GetInfoFunc:
        """
        Return a function that creates StructureInfo for constrained generation.

        The returned function takes a tool name and returns a StructureInfo object
        containing the begin/end patterns and trigger tokens needed for constrained
        generation of function calls in this format.

        Returns:
            A function that takes a tool name (str) and returns StructureInfo
        """
        raise NotImplementedError()
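The abstract surface above is small; as a minimal sketch, a custom detector only needs markers plus the three abstract methods. The `[TOOL_CALLS]`-style markers below are assumptions for illustration (the test file later in this diff uses essentially the same shape), not a format this commit ships:

```python
import json

from tensorrt_llm.serve.tool_parser.base_tool_parser import BaseToolParser
from tensorrt_llm.serve.tool_parser.core_types import (StreamingParseResult,
                                                       StructureInfo)


class BracketToolParser(BaseToolParser):
    """Hypothetical parser for a '[TOOL_CALLS] {...}[/TOOL_CALLS]' format."""

    def __init__(self):
        super().__init__()
        self.bot_token = "[TOOL_CALLS] "
        self.eot_token = "[/TOOL_CALLS]"

    def has_tool_call(self, text: str) -> bool:
        return self.bot_token in text

    def detect_and_parse(self, text, tools) -> StreamingParseResult:
        # Hand the JSON payload after the marker to parse_base_json.
        if not self.has_tool_call(text):
            return StreamingParseResult(normal_text=text)
        payload = text.split(self.bot_token, 1)[1].replace(self.eot_token, "")
        return StreamingParseResult(
            calls=self.parse_base_json(json.loads(payload), tools))

    def structure_info(self):
        return lambda name: StructureInfo(
            begin=f'[TOOL_CALLS] {{"name":"{name}", "arguments":',
            end="}[/TOOL_CALLS]",
            trigger="[TOOL_CALLS]")
```

Streaming then comes for free from the base class's `parse_streaming_increment`, as long as the format matches its assumptions (marker followed immediately by JSON).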
tensorrt_llm/serve/tool_parser/core_types.py (new file, 35 lines)

# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/qwen25_detector.py
from dataclasses import dataclass
from typing import Callable, List, Optional

from pydantic import BaseModel


class ToolCallItem(BaseModel):
    """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""

    tool_index: int
    name: Optional[str] = None
    parameters: str  # JSON string


class StreamingParseResult(BaseModel):
    """Result of streaming incremental parsing."""

    normal_text: str = ""
    calls: List[ToolCallItem] = []


@dataclass
class StructureInfo:
    begin: str
    end: str
    trigger: str


"""
Helper function alias: a function that takes a name string and returns a
StructureInfo object, which can be used to construct a structural_tag object.
"""
_GetInfoFunc = Callable[[str], StructureInfo]
tensorrt_llm/serve/tool_parser/qwen3_tool_parser.py (new file, 114 lines)

# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/qwen25_detector.py
import json
import re
from typing import List

from tensorrt_llm.logger import logger

from ..openai_protocol import ChatCompletionToolsParam as Tool
from .base_tool_parser import BaseToolParser
from .core_types import StreamingParseResult, StructureInfo, _GetInfoFunc


class Qwen3ToolParser(BaseToolParser):
    """
    Detector for Qwen 2.5 and Qwen 3 model function call format.

    Format Structure:
    ```
    <tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
    ```

    Key Components:
    - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
    - Function Call Object: JSON object with "name" and "arguments" fields

    Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "<tool_call>\n"  # nosec B105
        self.eot_token = "\n</tool_call>"  # nosec B105
        self.tool_call_separator = "\n"
        self._normal_text_buffer = ""  # Buffer for handling partial end tokens

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Qwen 3 format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str,
                         tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
        """
        idx = text.find(self.bot_token)
        normal_text = text[:idx].strip() if idx != -1 else text
        if self.bot_token not in text:
            return StreamingParseResult(normal_text=normal_text, calls=[])

        # Find all <tool_call>\n...\n</tool_call> blocks
        pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
        match_result_list = re.findall(pattern, text, re.DOTALL)
        calls = []
        for match_result in match_result_list:
            try:
                parsed_call = json.loads(match_result.strip())
                calls.extend(self.parse_base_json(parsed_call, tools))
            except json.JSONDecodeError as e:
                logger.warning(
                    f"Failed to parse JSON part: {match_result}, JSON parse error: {str(e)}"
                )
                continue
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def parse_streaming_increment(self, new_text: str,
                                  tools: List[Tool]) -> StreamingParseResult:
        """
        Streaming incremental parsing for Qwen 3 tool calls.
        Uses base class implementation with buffering to handle partial end tokens.
        """
        result = super().parse_streaming_increment(new_text, tools)

        # Handle partial end tokens that are streamed character by character
        if result.normal_text:
            self._normal_text_buffer += result.normal_text

            # Check if buffer contains complete end token (without leading newline)
            end_token_without_newline = self.eot_token[1:]  # "</tool_call>"
            if end_token_without_newline in self._normal_text_buffer:
                cleaned_text = self._normal_text_buffer.replace(
                    end_token_without_newline, "")
                self._normal_text_buffer = ""
                result.normal_text = cleaned_text
            else:
                # Check if buffer might contain partial end token at the end
                partial_match_len = self._ends_with_partial_token(
                    self._normal_text_buffer, end_token_without_newline)

                if partial_match_len:
                    # Keep potential partial match in buffer, return the rest
                    result.normal_text = self._normal_text_buffer[:
                                                                  -partial_match_len]
                    self._normal_text_buffer = self._normal_text_buffer[
                        -partial_match_len:]
                else:
                    # No partial match, return all buffered text
                    result.normal_text = self._normal_text_buffer
                    self._normal_text_buffer = ""

        return result

    def structure_info(self) -> _GetInfoFunc:
        return lambda name: StructureInfo(
            begin='<tool_call>\n{"name":"' + name + '", "arguments":',
            end="}\n</tool_call>",
            trigger="<tool_call>",
        )
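To see the streaming contract in action, a short sketch that drives the parser by hand (chunk boundaries are arbitrary; constructing a `FunctionDefinition` with only a name is assumed to be sufficient here, matching the protocol types used in the tests below):

```python
from tensorrt_llm.serve.openai_protocol import (ChatCompletionToolsParam,
                                                FunctionDefinition)
from tensorrt_llm.serve.tool_parser.qwen3_tool_parser import Qwen3ToolParser

tools = [
    ChatCompletionToolsParam(type="function",
                             function=FunctionDefinition(name="get_weather"))
]
parser = Qwen3ToolParser()
chunks = ['<tool_call>\n{"name":"get_weather"',
          ',"arguments":{"location":"Paris"}}\n</tool_call>']
for chunk in chunks:
    result = parser.parse_streaming_increment(chunk, tools)
    for call in result.calls:
        # First increment: the tool name with empty parameters.
        # Later increments: argument-JSON deltas only (name is None).
        print(call.tool_index, call.name, repr(call.parameters))
```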
tensorrt_llm/serve/tool_parser/tool_parser_factory.py (new file, 21 lines)

from typing import Type

from .base_tool_parser import BaseToolParser
from .qwen3_tool_parser import Qwen3ToolParser


class ToolParserFactory:
    parsers: dict[str, Type[BaseToolParser]] = {
        "qwen3": Qwen3ToolParser,
    }

    @staticmethod
    def create_tool_parser(tool_parser: str) -> BaseToolParser:
        try:
            tool_parser_class = ToolParserFactory.parsers[tool_parser.lower()]
            return tool_parser_class()
        except KeyError as e:
            raise ValueError(
                f"Invalid tool_parser: {tool_parser}\n"
                f"Supported parsers: {list(ToolParserFactory.parsers.keys())}"
            ) from e
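Lookup is case-insensitive (via `.lower()`) and unknown names fail loudly:

```python
from tensorrt_llm.serve.tool_parser import ToolParserFactory

parser = ToolParserFactory.create_tool_parser("qwen3")  # fresh Qwen3ToolParser
ToolParserFactory.create_tool_parser("nonexistent")     # raises ValueError
```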
tensorrt_llm/serve/tool_parser/utils.py (new file, 56 lines)

# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/qwen25_detector.py
import json
from json import JSONDecodeError, JSONDecoder
from json.decoder import WHITESPACE
from typing import Any

import partial_json_parser
from partial_json_parser.core.options import Allow


def find_common_prefix(s1: str, s2: str) -> str:
    prefix = ""
    min_length = min(len(s1), len(s2))
    for i in range(0, min_length):
        if s1[i] == s2[i]:
            prefix += s1[i]
        else:
            break
    return prefix


def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
    """
    Parse incomplete or partial JSON strings commonly encountered during streaming.

    Args:
        input_str (str): The potentially incomplete JSON string to parse.
        flags (Allow): Bitwise flags controlling what types of partial data are allowed.
            Common flags include:
            - Allow.STR: Allow partial strings (e.g., '"hello wo' -> 'hello wo')
            - Allow.OBJ: Allow partial objects (e.g., '{"key":' -> {'key': None})
            - Allow.ARR: Allow partial arrays (e.g., '[1, 2,' -> [1, 2])
            - Allow.ALL: Allow all types of partial data

    Returns:
        Tuple[Any, int]: A tuple containing:
            - parsed_object: The Python object parsed from the JSON
            - consumed_length: Number of characters consumed from input_str
    """
    try:
        return (partial_json_parser.loads(input_str, flags), len(input_str))
    except (JSONDecodeError, IndexError) as e:
        msg = getattr(e, "msg", str(e))
        if "Extra data" in msg or "pop from empty list" in msg:
            start = WHITESPACE.match(input_str, 0).end()
            obj, end = JSONDecoder().raw_decode(input_str, start)
            return obj, end
        raise


def is_complete_json(input_str: str) -> bool:
    try:
        json.loads(input_str)
        return True
    except JSONDecodeError:
        return False
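A quick illustration of these helpers on streaming-shaped input, assuming they are imported from the module above (the commented outputs reflect what `partial_json_parser` is documented to produce for partial strings under `Allow.ALL`):

```python
from partial_json_parser.core.options import Allow

from tensorrt_llm.serve.tool_parser.utils import (is_complete_json,
                                                  partial_json_loads)

obj, consumed = partial_json_loads(
    '{"name":"get_weather","arguments":{"location":"Par', Allow.ALL)
# obj -> {'name': 'get_weather', 'arguments': {'location': 'Par'}}
# consumed -> length of the input string (everything was consumed)

is_complete_json('{"name":"get_weather"}')  # True
is_complete_json('{"name":"get_weath')      # False
```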
@@ -1613,6 +1613,13 @@ def test_openai_reasoning(llm_root, llm_venv, backend: str):
         ])


+def test_openai_tool_call(llm_root, llm_venv):
+    test_root = unittest_path() / "llmapi" / "apps"
+    llm_venv.run_cmd(
+        ["-m", "pytest",
+         str(test_root / "_test_openai_tool_call.py")])
+
+
 @pytest.mark.parametrize("sampler", ["torch_sampler", "trtllm_sampler"])
 def test_openai_completions_with_logit_bias(llm_root, llm_venv, sampler: str):
     test_root = unittest_path() / "llmapi" / "apps"
@@ -57,6 +57,7 @@ l0_a10:
       - test_e2e.py::test_trtllm_serve_top_logprobs[pytorch]
       - test_e2e.py::test_openai_misc_example[pytorch]
      - test_e2e.py::test_openai_reasoning[pytorch]
+     - test_e2e.py::test_openai_tool_call
      - test_e2e.py::test_openai_completions_example[pytorch]
      - test_e2e.py::test_openai_chat_example[pytorch] TIMEOUT (90)
      - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
@@ -71,6 +72,8 @@ l0_a10:
      - unittest/llmapi/test_additional_model_outputs.py
      # executor
      - unittest/executor/test_rpc.py
+     # trtllm-serve CPU-only
+     - unittest/llmapi/apps/test_tool_parsers.py
  - condition:
      ranges:
        system_gpu_count:
@@ -1,3 +1,3 @@
-This directory contains the end-to-end tests for the LLM API applications in `examples/apps`.
+This directory contains the end-to-end tests for `trtllm-serve`.

 These tests are triggered in the `test_e2e.py`.
tests/unittest/llmapi/apps/_test_openai_tool_call.py (new file, 121 lines)

import json
import os
import sys

import openai
import pytest

from .openai_server import RemoteOpenAIServer

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from test_llm import get_model_path

TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_current_temperature",
            "description": "Get current temperature at a location.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description":
                        'The location to get the temperature for, in the format "City, State, Country".',
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description":
                        'The unit to return the temperature in. Defaults to "celsius".',
                    },
                },
                "required": ["location"],
            },
        },
    },
]


def get_current_temperature(location: str, unit: str = "celsius") -> dict:
    return {"temperature": 20 if unit == "celsius" else 68}


@pytest.fixture(scope="module", ids=["Qwen3-0.6B"])
def model_name() -> str:
    return "Qwen3/Qwen3-0.6B"


@pytest.fixture(scope="module")
def server(model_name: str):
    model_path = get_model_path(model_name)
    args = ["--tool_parser", "qwen3"]
    with RemoteOpenAIServer(model_path, cli_args=args) as remote_server:
        yield remote_server


@pytest.fixture(scope="module")
def client(server: RemoteOpenAIServer):
    return server.get_async_client()


@pytest.mark.asyncio(loop_scope="module")
async def test_tool_parser(client: openai.AsyncOpenAI, model_name: str):
    response = await client.chat.completions.create(
        model=model_name,
        messages=[{
            "role": "user",
            "content": "What's the temperature in San Francisco now?"
        }],
        tools=TOOLS)
    assert response.choices[0].finish_reason == "tool_calls"
    message = response.choices[0].message
    assert message.content is not None
    assert message.tool_calls is not None
    assert len(message.tool_calls) == 1
    tool_call = message.tool_calls[0]
    assert tool_call.function.name == "get_current_temperature"
    args = json.loads(tool_call.function.arguments)
    get_current_temperature(**args)


@pytest.mark.asyncio(loop_scope="module")
async def test_tool_parser_streaming(client: openai.AsyncOpenAI,
                                     model_name: str):
    response = await client.chat.completions.create(
        model=model_name,
        messages=[{
            "role": "user",
            "content": "What's the temperature in San Francisco now?"
        }],
        tools=TOOLS,
        stream=True)
    tool_id = None
    tool_name = None
    parameters = ""
    finish_reason = None

    async for chunk in response:
        if chunk.choices[0].delta.tool_calls:
            tool_call = chunk.choices[0].delta.tool_calls[0]
            if tool_call.id:
                if tool_id is not None:
                    raise RuntimeError("tool_id already exists")
                tool_id = tool_call.id
            if tool_call.function.name:
                if tool_name is not None:
                    raise RuntimeError("tool_name already exists")
                tool_name = tool_call.function.name
            if tool_call.function.arguments:
                parameters += tool_call.function.arguments
        if chunk.choices[0].finish_reason:
            finish_reason = chunk.choices[0].finish_reason
    assert tool_id is not None
    assert tool_name == "get_current_temperature"
    assert finish_reason == "tool_calls"
    assert parameters
    args = json.loads(parameters)
    get_current_temperature(**args)
tests/unittest/llmapi/apps/test_tool_parsers.py (new file, 597 lines)

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

import pytest

from tensorrt_llm.serve.openai_protocol import (ChatCompletionToolsParam,
                                                FunctionDefinition)
from tensorrt_llm.serve.tool_parser.base_tool_parser import BaseToolParser
from tensorrt_llm.serve.tool_parser.core_types import StructureInfo
from tensorrt_llm.serve.tool_parser.qwen3_tool_parser import Qwen3ToolParser


# Test fixtures for common tools
@pytest.fixture
def sample_tools():
    """Sample tools for testing."""
    return [
        ChatCompletionToolsParam(
            type="function",
            function=FunctionDefinition(name="get_weather",
                                        description="Get the current weather",
                                        parameters={
                                            "type": "object",
                                            "properties": {
                                                "location": {
                                                    "type": "string",
                                                    "description":
                                                    "The city and state"
                                                },
                                                "unit": {
                                                    "type": "string",
                                                    "enum":
                                                    ["celsius", "fahrenheit"]
                                                }
                                            },
                                            "required": ["location"]
                                        })),
        ChatCompletionToolsParam(
            type="function",
            function=FunctionDefinition(name="search_web",
                                        description="Search the web",
                                        parameters={
                                            "type": "object",
                                            "properties": {
                                                "query": {
                                                    "type": "string",
                                                    "description":
                                                    "The search query"
                                                }
                                            },
                                            "required": ["query"]
                                        }))
    ]


# Concrete implementation of BaseToolParser for testing
class ConcreteToolParser(BaseToolParser):
    """Concrete implementation of BaseToolParser for testing abstract methods."""

    def __init__(self):
        super().__init__()
        self.bot_token = "[TOOL_CALLS] "
        self.eot_token = "[/TOOL_CALLS]"

    def has_tool_call(self, text: str) -> bool:
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools):
        # Placeholder to avoid NotImplementedError
        pass

    def structure_info(self):
        return lambda name: StructureInfo(
            begin=f'[TOOL_CALLS] {{"name":"{name}", "arguments":',
            end="}[/TOOL_CALLS]",
            trigger="[TOOL_CALLS]")


# ============================================================================
# BaseToolParser Tests
# ============================================================================


class TestBaseToolParser:
    """Test suite for BaseToolParser class."""

    def test_initialization(self):
        """Test that BaseToolParser initializes correctly."""
        parser = ConcreteToolParser()
        assert parser._buffer == ""
        assert parser.prev_tool_call_arr == []
        assert parser.current_tool_id == -1
        assert parser.current_tool_name_sent is False
        assert parser.streamed_args_for_tool == []

    def test_get_tool_indices(self, sample_tools):
        """Test _get_tool_indices correctly maps tool names to indices."""
        parser = ConcreteToolParser()
        indices = parser._get_tool_indices(sample_tools)

        assert len(indices) == 2
        assert indices["get_weather"] == 0
        assert indices["search_web"] == 1

    def test_get_tool_indices_empty(self):
        """Test _get_tool_indices with empty tools list."""
        parser = ConcreteToolParser()
        indices = parser._get_tool_indices([])
        assert indices == {}

    def test_parse_base_json_single_tool(self, sample_tools):
        """Test parse_base_json with a single tool call."""
        parser = ConcreteToolParser()
        action = {
            "name": "get_weather",
            "parameters": {
                "location": "San Francisco"
            }
        }

        results = parser.parse_base_json(action, sample_tools)

        assert len(results) == 1
        assert results[0].name == "get_weather"
        assert json.loads(results[0].parameters) == {
            "location": "San Francisco"
        }

    def test_parse_base_json_with_arguments_key(self, sample_tools):
        """Test parse_base_json handles 'arguments' key instead of 'parameters'."""
        parser = ConcreteToolParser()
        action = {"name": "search_web", "arguments": {"query": "TensorRT"}}

        results = parser.parse_base_json(action, sample_tools)

        assert len(results) == 1
        assert results[0].name == "search_web"
        assert json.loads(results[0].parameters) == {"query": "TensorRT"}

    def test_parse_base_json_multiple_tools(self, sample_tools):
        """Test parse_base_json with multiple tool calls."""
        parser = ConcreteToolParser()
        actions = [{
            "name": "get_weather",
            "parameters": {
                "location": "Boston"
            }
        }, {
            "name": "search_web",
            "arguments": {
                "query": "Python"
            }
        }]

        results = parser.parse_base_json(actions, sample_tools)

        assert len(results) == 2
        assert results[0].name == "get_weather"
        assert results[1].name == "search_web"

    def test_parse_base_json_undefined_function(self, sample_tools):
        """Test parse_base_json handles undefined function names gracefully."""
        parser = ConcreteToolParser()
        action = {"name": "undefined_function", "parameters": {}}

        results = parser.parse_base_json(action, sample_tools)

        # Should return empty list and log warning
        assert len(results) == 0

    def test_parse_base_json_missing_parameters(self, sample_tools):
        """Test parse_base_json handles missing parameters."""
        parser = ConcreteToolParser()
        action = {"name": "get_weather"}

        results = parser.parse_base_json(action, sample_tools)

        assert len(results) == 1
        assert json.loads(results[0].parameters) == {}

    def test_ends_with_partial_token(self):
        """Test _ends_with_partial_token detection."""
        parser = ConcreteToolParser()

        # Partial token at end (bot_token starts with the suffix)
        assert parser._ends_with_partial_token("Some text [TOOL",
                                               "[TOOL_CALLS] ") == 5
        assert parser._ends_with_partial_token("Some text [",
                                               "[TOOL_CALLS] ") == 1
        assert parser._ends_with_partial_token("Some text [TOOL_CALLS",
                                               "[TOOL_CALLS] ") == 11

        # No partial token
        assert parser._ends_with_partial_token("Some text",
                                               "[TOOL_CALLS] ") == 0
        assert parser._ends_with_partial_token("Some text [XYZ",
                                               "[TOOL_CALLS] ") == 0

        # Complete token at end (entire buffer is bot_token prefix but not complete match)
        # When buffer equals bot_token, it returns 0 because it's not a partial anymore
        assert parser._ends_with_partial_token("text [TOOL_CALLS] ",
                                               "[TOOL_CALLS] ") == 0

    def test_parse_streaming_increment_no_tool_call(self, sample_tools):
        """Test streaming parser returns normal text when no tool call present."""
        parser = ConcreteToolParser()

        result = parser.parse_streaming_increment("Hello, world!", sample_tools)

        assert result.normal_text == "Hello, world!"
        assert len(result.calls) == 0

    def test_parse_streaming_increment_partial_bot_token(self, sample_tools):
        """Test streaming parser buffers partial bot token."""
        parser = ConcreteToolParser()

        # Send partial bot token
        result = parser.parse_streaming_increment("[TOOL", sample_tools)

        # Should buffer and return nothing
        assert result.normal_text == ""
        assert len(result.calls) == 0
        assert parser._buffer == "[TOOL"

    def test_parse_streaming_increment_tool_name(self, sample_tools):
        """Test streaming parser handles tool name streaming."""
        parser = ConcreteToolParser()

        # Send bot token with partial JSON containing name
        result = parser.parse_streaming_increment(
            '[TOOL_CALLS] {"name":"get_weather"', sample_tools)

        # Should send tool name with empty parameters
        assert len(result.calls) == 1
        assert result.calls[0].name == "get_weather"
        assert result.calls[0].parameters == ""
        assert result.calls[0].tool_index == 0
        assert parser.current_tool_name_sent is True

    def test_parse_streaming_increment_tool_arguments(self, sample_tools):
        """Test streaming parser handles incremental argument streaming."""
        parser = ConcreteToolParser()

        # First send tool name
        result1 = parser.parse_streaming_increment(
            '[TOOL_CALLS] {"name":"get_weather"', sample_tools)
        # Should send tool name
        assert len(result1.calls) == 1
        assert result1.calls[0].name == "get_weather"

        # Then send complete arguments (parser needs complete JSON to parse incrementally)
        result2 = parser.parse_streaming_increment(
            ',"arguments":{"location":"San Francisco"}}', sample_tools)

        # Should stream arguments or complete the tool call
        # The base implementation uses partial JSON parsing, so it may return results
        assert result2 is not None  # Just verify it doesn't crash

    def test_parse_streaming_increment_complete_tool(self, sample_tools):
        """Test streaming parser handles complete tool call."""
        parser = ConcreteToolParser()

        # Send complete tool call in one chunk
        result = parser.parse_streaming_increment(
            '[TOOL_CALLS] {"name":"get_weather","arguments":{"location":"Boston"}}',
            sample_tools)

        # Should have sent tool name (first call)
        assert len(result.calls) == 1
        assert result.calls[0].name == "get_weather"

    def test_parse_streaming_increment_invalid_tool_name(self, sample_tools):
        """Test streaming parser handles invalid tool name."""
        parser = ConcreteToolParser()

        # Send invalid tool name
        result = parser.parse_streaming_increment(
            '[TOOL_CALLS] {"name":"invalid_tool"', sample_tools)

        # Should reset state
        assert len(result.calls) == 0
        assert parser._buffer == ""
        assert parser.current_tool_id == -1

    def test_supports_structural_tag(self):
        """Test supports_structural_tag returns True."""
        parser = ConcreteToolParser()
        assert parser.supports_structural_tag() is True

    def test_structure_info(self):
        """Test structure_info returns proper function."""
        parser = ConcreteToolParser()
        func = parser.structure_info()

        info = func("test_function")
        assert isinstance(info, StructureInfo)
        assert "test_function" in info.begin
        assert info.trigger == "[TOOL_CALLS]"


# ============================================================================
# Qwen3ToolParser Tests
# ============================================================================


class TestQwen3ToolParser:
    """Test suite for Qwen3ToolParser class."""

    def test_initialization(self):
        """Test that Qwen3ToolParser initializes correctly."""
        parser = Qwen3ToolParser()

        assert parser.bot_token == "<tool_call>\n"
        assert parser.eot_token == "\n</tool_call>"
        assert parser.tool_call_separator == "\n"
        assert parser._normal_text_buffer == ""

    def test_has_tool_call_true(self):
        """Test has_tool_call returns True when tool call is present."""
        parser = Qwen3ToolParser()
        text = 'Some text <tool_call>\n{"name":"get_weather"}\n</tool_call>'

        assert parser.has_tool_call(text) is True

    def test_has_tool_call_false(self):
        """Test has_tool_call returns False when no tool call present."""
        parser = Qwen3ToolParser()
        text = "Just some regular text without tool calls"

        assert parser.has_tool_call(text) is False

    def test_detect_and_parse_no_tool_call(self, sample_tools):
        """Test detect_and_parse with text containing no tool calls."""
        parser = Qwen3ToolParser()
        text = "This is just a regular response."

        result = parser.detect_and_parse(text, sample_tools)

        assert result.normal_text == "This is just a regular response."
        assert len(result.calls) == 0

    def test_detect_and_parse_single_tool(self, sample_tools):
        """Test detect_and_parse with a single tool call."""
        parser = Qwen3ToolParser()
        text = 'Normal text\n<tool_call>\n{"name":"get_weather","arguments":{"location":"NYC"}}\n</tool_call>'

        result = parser.detect_and_parse(text, sample_tools)

        assert result.normal_text == "Normal text"
        assert len(result.calls) == 1
        assert result.calls[0].name == "get_weather"
        assert json.loads(result.calls[0].parameters) == {"location": "NYC"}

    def test_detect_and_parse_multiple_tools(self, sample_tools):
        """Test detect_and_parse with multiple tool calls."""
        parser = Qwen3ToolParser()
        text = (
            '<tool_call>\n{"name":"get_weather","arguments":{"location":"LA"}}\n</tool_call>\n'
            '<tool_call>\n{"name":"search_web","arguments":{"query":"AI"}}\n</tool_call>'
        )

        result = parser.detect_and_parse(text, sample_tools)

        assert len(result.calls) == 2
        assert result.calls[0].name == "get_weather"
        assert result.calls[1].name == "search_web"

    def test_detect_and_parse_malformed_json(self, sample_tools):
        """Test detect_and_parse handles malformed JSON gracefully."""
        parser = Qwen3ToolParser()
        text = '<tool_call>\n{"name":"get_weather","arguments":MALFORMED}\n</tool_call>'

        result = parser.detect_and_parse(text, sample_tools)

        # Should return empty calls due to JSON parsing error
        assert len(result.calls) == 0

    def test_detect_and_parse_with_parameters_key(self, sample_tools):
        """Test detect_and_parse handles 'parameters' key."""
        parser = Qwen3ToolParser()
        text = '<tool_call>\n{"name":"search_web","parameters":{"query":"test"}}\n</tool_call>'

        result = parser.detect_and_parse(text, sample_tools)

        assert len(result.calls) == 1
        assert result.calls[0].name == "search_web"
        assert json.loads(result.calls[0].parameters) == {"query": "test"}

    def test_parse_streaming_increment_normal_text(self, sample_tools):
        """Test streaming parser handles normal text without tool calls."""
        parser = Qwen3ToolParser()

        result = parser.parse_streaming_increment("Hello, how can I help?",
                                                  sample_tools)

        assert result.normal_text == "Hello, how can I help?"
        assert len(result.calls) == 0

    def test_parse_streaming_increment_partial_bot_token(self, sample_tools):
        """Test streaming parser buffers partial bot token."""
        parser = Qwen3ToolParser()

        # Send partial bot token
        result = parser.parse_streaming_increment("<tool", sample_tools)

        # Should buffer
        assert result.normal_text == ""
        assert len(result.calls) == 0

    def test_parse_streaming_increment_complete_tool_call(self, sample_tools):
        """Test streaming parser with complete tool call in chunks."""
        parser = Qwen3ToolParser()

        # Send bot token
        parser.parse_streaming_increment("<tool_call>\n", sample_tools)

        # Send partial JSON with name
        result = parser.parse_streaming_increment('{"name":"get_weather"',
                                                  sample_tools)

        # Should send tool name
        assert len(result.calls) == 1
        assert result.calls[0].name == "get_weather"
        assert result.calls[0].parameters == ""

        # Send arguments
        result = parser.parse_streaming_increment(
            ',"arguments":{"location":"SF"}}\n</tool_call>', sample_tools)

        # Should stream arguments
        assert len(result.calls) == 1
        assert json.loads(result.calls[0].parameters) == {"location": "SF"}

    def test_parse_streaming_increment_end_token_handling(self, sample_tools):
        """Test streaming parser handles end token correctly."""
        parser = Qwen3ToolParser()

        # Send complete tool call
        parser.parse_streaming_increment(
            '<tool_call>\n{"name":"get_weather","arguments":{"location":"NYC"}}\n</tool_call>',
            sample_tools)

        # The end token should be removed from normal text
        # Check buffer state
        assert parser._normal_text_buffer == ""

    def test_parse_streaming_increment_multiple_tools_streaming(
            self, sample_tools):
        """Test streaming parser handles multiple tool calls."""
        parser = Qwen3ToolParser()

        # First tool
        parser.parse_streaming_increment('<tool_call>\n', sample_tools)
        parser.parse_streaming_increment(
            '{"name":"get_weather","arguments":{"location":"NYC"}}\n</tool_call>\n',
            sample_tools)

        # Second tool
        parser.parse_streaming_increment('<tool_call>\n', sample_tools)
        result = parser.parse_streaming_increment('{"name":"search_web"',
                                                  sample_tools)

        # Should have started second tool
        assert result.calls[0].name == "search_web"
        assert result.calls[0].parameters == ""
        assert result.calls[0].tool_index == 1

    def test_structure_info_function(self):
        """Test structure_info returns correct lambda function."""
        parser = Qwen3ToolParser()
        func = parser.structure_info()

        info = func("test_function")

        assert isinstance(info, StructureInfo)
        assert info.begin == '<tool_call>\n{"name":"test_function", "arguments":'
        assert info.end == "}\n</tool_call>"
        assert info.trigger == "<tool_call>"

    def test_structure_info_different_names(self):
        """Test structure_info works with different function names."""
        parser = Qwen3ToolParser()
        func = parser.structure_info()

        info1 = func("get_weather")
        info2 = func("search_web")

        assert "get_weather" in info1.begin
        assert "search_web" in info2.begin
        assert info1.end == info2.end == "}\n</tool_call>"

    def test_qwen3_format_compliance(self, sample_tools):
        """Test that Qwen3ToolParser follows the documented format structure."""
        parser = Qwen3ToolParser()

        # Test the exact format from the docstring
        text = '<tool_call>\n{"name":"get_weather", "arguments":{"location":"Tokyo"}}\n</tool_call>'

        result = parser.detect_and_parse(text, sample_tools)

        assert len(result.calls) == 1
        assert result.calls[0].name == "get_weather"
        assert json.loads(result.calls[0].parameters) == {"location": "Tokyo"}

    def test_undefined_tool_in_qwen3_format(self, sample_tools):
        """Test Qwen3ToolParser handles undefined tool gracefully."""
        parser = Qwen3ToolParser()
        text = '<tool_call>\n{"name":"undefined_func","arguments":{}}\n</tool_call>'

        result = parser.detect_and_parse(text, sample_tools)

        # Should not return any calls for undefined function
        assert len(result.calls) == 0


# ============================================================================
# Integration Tests
# ============================================================================


class TestToolParserIntegration:
    """Integration tests for tool parsers."""

    def test_end_to_end_single_tool(self, sample_tools):
        """Test end-to-end parsing of a single tool call."""
        parser = Qwen3ToolParser()

        # Simulate streaming
        chunks = [
            "<tool_call>\n", '{"name":"get', '_weather"', ',"arguments":',
            '{"location"', ':"Paris"}}\n', '</tool_call>'
        ]

        results = []
        for chunk in chunks:
            result = parser.parse_streaming_increment(chunk, sample_tools)
            if result.calls or result.normal_text:
                results.append(result)

        # Should have received tool name and arguments
        assert any(r.calls for r in results)

    def test_mixed_content_and_tool_calls(self, sample_tools):
        """Test parsing text that mixes normal content with tool calls."""
        parser = Qwen3ToolParser()

        text = (
            'I will check the weather for you.\n'
            '<tool_call>\n{"name":"get_weather","arguments":{"location":"London"}}\n</tool_call>\n'
            'Let me search that for you.')

        result = parser.detect_and_parse(text, sample_tools)

        assert "I will check the weather for you." in result.normal_text
        assert len(result.calls) == 1
        assert result.calls[0].name == "get_weather"

    def test_parser_state_reset(self, sample_tools):
        """Test that parser state can be used for multiple requests."""
        parser = Qwen3ToolParser()

        # First request
        result1 = parser.detect_and_parse(
            '<tool_call>\n{"name":"get_weather","arguments":{"location":"NYC"}}\n</tool_call>',
            sample_tools)

        # Reset internal state for new request
        parser2 = Qwen3ToolParser()

        # Second request
        result2 = parser2.detect_and_parse(
            '<tool_call>\n{"name":"search_web","arguments":{"query":"test"}}\n</tool_call>',
            sample_tools)

        assert result1.calls[0].name == "get_weather"
        assert result2.calls[0].name == "search_web"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])