mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Frontend] add FunctionGemma tool parser support (#31218)
Signed-off-by: gateremark <gateremg@gmail.com>
This commit is contained in:
@@ -372,6 +372,30 @@ Supported models:
|
||||
|
||||
Flags: `--tool-call-parser glm47`
|
||||
|
||||
### FunctionGemma Models (`functiongemma`)
|
||||
|
||||
Google's FunctionGemma is a lightweight (270M parameter) model specifically designed for function calling.
|
||||
It's built on Gemma 3 and optimized for edge deployment on devices like laptops and phones.
|
||||
|
||||
Supported models:
|
||||
|
||||
* `google/functiongemma-270m-it`
|
||||
|
||||
FunctionGemma uses a unique output format with `<start_function_call>` and `<end_function_call>` tags:
|
||||
|
||||
```text
|
||||
<start_function_call>call:get_weather{location:<escape>London<escape>}<end_function_call>
|
||||
```
|
||||
|
||||
The model is designed to be fine-tuned for specific function-calling tasks for best results.
|
||||
|
||||
Flags: `--tool-call-parser functiongemma --chat-template examples/tool_chat_template_functiongemma.jinja`
|
||||
|
||||
!!! note
|
||||
FunctionGemma is intended to be fine-tuned for your specific function-calling task.
|
||||
The base model provides general function calling capabilities, but best results
|
||||
are achieved with task-specific fine-tuning. See Google's [FunctionGemma documentation](https://ai.google.dev/gemma/docs/functiongemma) for fine-tuning guides.
|
||||
|
||||
### Qwen3-Coder Models (`qwen3_xml`)
|
||||
|
||||
Supported models:
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
{%- set ns = namespace(developer_content='', has_tools=false) -%}
|
||||
|
||||
{%- if tools is defined and tools | length > 0 -%}
|
||||
{%- set ns.has_tools = true -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- for message in messages -%}
|
||||
{%- if message.role == 'developer' or message.role == 'system' -%}
|
||||
<start_of_turn>user
|
||||
{{ message.content }}
|
||||
{%- if ns.has_tools %}
|
||||
|
||||
Available functions:
|
||||
{%- for tool in tools %}
|
||||
{%- if tool.type == 'function' %}
|
||||
|
||||
Function: {{ tool.function.name }}
|
||||
Description: {{ tool.function.description | default('No description provided') }}
|
||||
Parameters: {{ tool.function.parameters | tojson }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
<end_of_turn>
|
||||
{%- elif message.role == 'user' -%}
|
||||
<start_of_turn>user
|
||||
{{ message.content }}<end_of_turn>
|
||||
{%- elif message.role == 'assistant' -%}
|
||||
{%- if message.tool_calls is defined and message.tool_calls | length > 0 -%}
|
||||
<start_of_turn>model
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
<start_function_call>call:{{ tool_call.function.name }}{
|
||||
{%- set args = tool_call.function.arguments -%}
|
||||
{%- if args is string -%}
|
||||
{%- set args = args | fromjson -%}
|
||||
{%- endif -%}
|
||||
{%- for key, value in args.items() -%}
|
||||
{{ key }}:<escape>{{ value }}<escape>{% if not loop.last %},{% endif %}
|
||||
{%- endfor -%}
|
||||
}<end_function_call>
|
||||
{%- endfor %}
|
||||
<end_of_turn>
|
||||
{%- else -%}
|
||||
<start_of_turn>model
|
||||
{{ message.content }}<end_of_turn>
|
||||
{%- endif -%}
|
||||
{%- elif message.role == 'tool' -%}
|
||||
<start_of_turn>user
|
||||
Function result for {{ message.name | default('function') }}: {{ message.content }}<end_of_turn>
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if add_generation_prompt -%}
|
||||
<start_of_turn>model
|
||||
{%- endif -%}
|
||||
@@ -0,0 +1,154 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.tool_parsers.functiongemma_tool_parser import FunctionGemmaToolParser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tokenizer():
|
||||
tokenizer = MagicMock()
|
||||
tokenizer.encode.return_value = [1, 2, 3]
|
||||
tokenizer.get_vocab.return_value = {}
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parser(mock_tokenizer):
|
||||
return FunctionGemmaToolParser(mock_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
request = MagicMock(spec=ChatCompletionRequest)
|
||||
request.tools = []
|
||||
request.tool_choice = "auto"
|
||||
return request
|
||||
|
||||
|
||||
class TestExtractToolCalls:
|
||||
def test_no_tool_calls(self, parser, mock_request):
|
||||
model_output = "Hello, how can I help you today?"
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is False
|
||||
assert result.tool_calls == []
|
||||
assert result.content == model_output
|
||||
|
||||
def test_single_tool_call(self, parser, mock_request):
|
||||
model_output = (
|
||||
"<start_function_call>call:get_weather{location:<escape>London<escape>}"
|
||||
"<end_function_call>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "get_weather"
|
||||
assert '"location": "London"' in result.tool_calls[0].function.arguments
|
||||
|
||||
def test_multiple_arguments(self, parser, mock_request):
|
||||
model_output = (
|
||||
"<start_function_call>call:get_weather{"
|
||||
"location:<escape>San Francisco<escape>,"
|
||||
"unit:<escape>celsius<escape>}"
|
||||
"<end_function_call>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "get_weather"
|
||||
args = result.tool_calls[0].function.arguments
|
||||
assert "San Francisco" in args
|
||||
assert "celsius" in args
|
||||
|
||||
def test_text_before_tool_call(self, parser, mock_request):
|
||||
model_output = (
|
||||
"Let me check the weather for you. "
|
||||
"<start_function_call>call:get_weather{location:<escape>Paris<escape>}"
|
||||
"<end_function_call>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert result.content == "Let me check the weather for you."
|
||||
|
||||
def test_multiple_tool_calls(self, parser, mock_request):
|
||||
model_output = (
|
||||
"<start_function_call>call:get_weather{location:<escape>London<escape>}"
|
||||
"<end_function_call>"
|
||||
"<start_function_call>call:get_time{timezone:<escape>UTC<escape>}"
|
||||
"<end_function_call>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 2
|
||||
assert result.tool_calls[0].function.name == "get_weather"
|
||||
assert result.tool_calls[1].function.name == "get_time"
|
||||
|
||||
|
||||
class TestParseArguments:
|
||||
def test_empty_arguments(self, parser):
|
||||
result = parser._parse_arguments("")
|
||||
assert result == {}
|
||||
|
||||
def test_single_string_argument(self, parser):
|
||||
result = parser._parse_arguments("city:<escape>Tokyo<escape>")
|
||||
assert result == {"city": "Tokyo"}
|
||||
|
||||
def test_multiple_arguments(self, parser):
|
||||
args_str = "city:<escape>Tokyo<escape>,country:<escape>Japan<escape>"
|
||||
result = parser._parse_arguments(args_str)
|
||||
assert result == {"city": "Tokyo", "country": "Japan"}
|
||||
|
||||
def test_numeric_argument(self, parser):
|
||||
result = parser._parse_arguments("count:<escape>42<escape>")
|
||||
assert result == {"count": 42}
|
||||
|
||||
def test_boolean_argument(self, parser):
|
||||
result = parser._parse_arguments("enabled:<escape>true<escape>")
|
||||
assert result == {"enabled": True}
|
||||
|
||||
def test_argument_with_spaces(self, parser):
|
||||
result = parser._parse_arguments("message:<escape>Hello World<escape>")
|
||||
assert result == {"message": "Hello World"}
|
||||
|
||||
|
||||
class TestAdjustRequest:
|
||||
def test_skip_special_tokens_disabled(self, parser, mock_request):
|
||||
mock_request.tools = [{"type": "function", "function": {"name": "test"}}]
|
||||
mock_request.tool_choice = "auto"
|
||||
mock_request.skip_special_tokens = True
|
||||
|
||||
result = parser.adjust_request(mock_request)
|
||||
assert result.skip_special_tokens is False
|
||||
|
||||
def test_skip_special_tokens_when_tool_choice_none(self, parser, mock_request):
|
||||
mock_request.tools = [{"type": "function", "function": {"name": "test"}}]
|
||||
mock_request.tool_choice = "none"
|
||||
mock_request.skip_special_tokens = True
|
||||
|
||||
result = parser.adjust_request(mock_request)
|
||||
assert result.skip_special_tokens is True
|
||||
|
||||
|
||||
class TestBufferDeltaText:
|
||||
def test_regular_text_not_buffered(self, parser):
|
||||
result = parser._buffer_delta_text("hello")
|
||||
assert result == "hello"
|
||||
assert parser.buffered_delta_text == ""
|
||||
|
||||
def test_complete_tag_flushed(self, parser):
|
||||
parser.buffered_delta_text = "<start_function_"
|
||||
result = parser._buffer_delta_text("call>")
|
||||
assert "<start_function_call>" in result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -142,6 +142,10 @@ _TOOL_PARSERS_TO_REGISTER = {
|
||||
"gigachat3_tool_parser",
|
||||
"GigaChat3ToolParser",
|
||||
),
|
||||
"functiongemma": (
|
||||
"functiongemma_tool_parser",
|
||||
"FunctionGemmaToolParser",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,321 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import ToolParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FunctionGemmaToolParser(ToolParser):
|
||||
"""
|
||||
Tool parser for Google's FunctionGemma model (google/functiongemma-270m-it).
|
||||
|
||||
Handles the FunctionGemma function call format:
|
||||
<start_function_call>call:func_name{param:<escape>value<escape>}<end_function_call>
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# Streaming state
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
|
||||
# FunctionGemma tokens
|
||||
self.tool_call_start_token: str = "<start_function_call>"
|
||||
self.tool_call_end_token: str = "<end_function_call>"
|
||||
|
||||
# Regex patterns
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>"
|
||||
r"|<start_function_call>call:(\w+)\{(.*)",
|
||||
re.DOTALL,
|
||||
)
|
||||
self.arg_regex = re.compile(
|
||||
r"(\w+):<escape>(.*?)<escape>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
if self.model_tokenizer:
|
||||
self.tool_call_start_token_ids = self.model_tokenizer.encode(
|
||||
self.tool_call_start_token, add_special_tokens=False
|
||||
)
|
||||
self.tool_call_end_token_ids = self.model_tokenizer.encode(
|
||||
self.tool_call_end_token, add_special_tokens=False
|
||||
)
|
||||
else:
|
||||
self.tool_call_start_token_ids = []
|
||||
self.tool_call_end_token_ids = []
|
||||
|
||||
self.buffered_delta_text = ""
|
||||
|
||||
def _parse_arguments(self, args_str: str) -> dict:
|
||||
"""Parse FunctionGemma argument string into a dictionary."""
|
||||
arguments = {}
|
||||
if not args_str:
|
||||
return arguments
|
||||
|
||||
matches = self.arg_regex.findall(args_str)
|
||||
for key, value in matches:
|
||||
try:
|
||||
parsed_value = json.loads(value)
|
||||
arguments[key] = parsed_value
|
||||
except json.JSONDecodeError:
|
||||
arguments[key] = value
|
||||
|
||||
return arguments
|
||||
|
||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
request = super().adjust_request(request)
|
||||
if request.tools and request.tool_choice != "none":
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
if self.tool_call_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
try:
|
||||
matches = self.tool_call_regex.findall(model_output)
|
||||
|
||||
if not matches:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
tool_calls: list[ToolCall] = []
|
||||
|
||||
for match in matches:
|
||||
func_name = match[0] if match[0] else match[2]
|
||||
args_str = match[1] if match[1] else match[3]
|
||||
|
||||
if not func_name:
|
||||
continue
|
||||
|
||||
arguments = self._parse_arguments(args_str)
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=func_name,
|
||||
arguments=json.dumps(arguments, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if tool_calls:
|
||||
content_end = model_output.find(self.tool_call_start_token)
|
||||
content = (
|
||||
model_output[:content_end].strip() if content_end > 0 else None
|
||||
)
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error extracting tool calls from FunctionGemma response")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
def _buffer_delta_text(self, delta_text: str) -> str:
|
||||
"""Buffer incoming delta text to handle multi-token special sequences."""
|
||||
potential_start = "<start_function_call>"
|
||||
potential_end = "<end_function_call>"
|
||||
|
||||
combined = self.buffered_delta_text + delta_text
|
||||
|
||||
if combined.endswith(potential_start) or combined.endswith(potential_end):
|
||||
self.buffered_delta_text = ""
|
||||
return combined
|
||||
|
||||
for tag in [potential_start, potential_end]:
|
||||
for i in range(1, len(tag)):
|
||||
if combined.endswith(tag[:i]):
|
||||
self.buffered_delta_text = combined[-(i):]
|
||||
return combined[:-i]
|
||||
|
||||
self.buffered_delta_text = ""
|
||||
return combined
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
delta_text = self._buffer_delta_text(delta_text)
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
if self.tool_call_start_token not in current_text:
|
||||
if delta_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
return None
|
||||
|
||||
try:
|
||||
start_count = current_text.count(self.tool_call_start_token)
|
||||
end_count = current_text.count(self.tool_call_end_token)
|
||||
prev_start_count = previous_text.count(self.tool_call_start_token)
|
||||
prev_end_count = previous_text.count(self.tool_call_end_token)
|
||||
|
||||
if self.tool_call_start_token not in current_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# Starting a new function call
|
||||
if start_count > prev_start_count and start_count > end_count:
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
self.prev_tool_call_arr.append({})
|
||||
logger.debug("Starting new tool call %d", self.current_tool_id)
|
||||
return None
|
||||
|
||||
# In the middle of a function call
|
||||
if start_count > end_count:
|
||||
last_start = current_text.rfind(self.tool_call_start_token)
|
||||
partial_call = current_text[
|
||||
last_start + len(self.tool_call_start_token) :
|
||||
]
|
||||
|
||||
if partial_call.startswith("call:"):
|
||||
func_part = partial_call[5:]
|
||||
|
||||
if "{" in func_part:
|
||||
func_name = func_part.split("{")[0]
|
||||
args_part = (
|
||||
func_part.split("{", 1)[1] if "{" in func_part else ""
|
||||
)
|
||||
|
||||
if not self.current_tool_name_sent and func_name:
|
||||
self.current_tool_name_sent = True
|
||||
self.prev_tool_call_arr[self.current_tool_id] = {
|
||||
"name": func_name,
|
||||
"arguments": {},
|
||||
}
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
type="function",
|
||||
id=make_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=func_name
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if self.current_tool_name_sent and args_part:
|
||||
current_args = self._parse_arguments(args_part)
|
||||
if current_args:
|
||||
current_args_json = json.dumps(
|
||||
current_args, ensure_ascii=False
|
||||
)
|
||||
prev_streamed = self.streamed_args_for_tool[
|
||||
self.current_tool_id
|
||||
]
|
||||
|
||||
if len(current_args_json) > len(prev_streamed):
|
||||
diff = current_args_json[len(prev_streamed) :]
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id
|
||||
] = current_args_json
|
||||
self.prev_tool_call_arr[self.current_tool_id][
|
||||
"arguments"
|
||||
] = current_args
|
||||
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
# Function call just ended
|
||||
if end_count > prev_end_count:
|
||||
if self.current_tool_id >= 0 and self.current_tool_id < len(
|
||||
self.prev_tool_call_arr
|
||||
):
|
||||
all_calls = self.tool_call_regex.findall(current_text)
|
||||
args = {}
|
||||
if self.current_tool_id < len(all_calls):
|
||||
match = all_calls[self.current_tool_id]
|
||||
if match[0]:
|
||||
args_str = match[1]
|
||||
args = self._parse_arguments(args_str)
|
||||
self.prev_tool_call_arr[self.current_tool_id][
|
||||
"arguments"
|
||||
] = args
|
||||
|
||||
if args:
|
||||
args_json = json.dumps(args, ensure_ascii=False)
|
||||
prev_streamed = self.streamed_args_for_tool[
|
||||
self.current_tool_id
|
||||
]
|
||||
if len(args_json) > len(prev_streamed):
|
||||
diff = args_json[len(prev_streamed) :]
|
||||
self.streamed_args_for_tool[self.current_tool_id] = (
|
||||
args_json
|
||||
)
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
return None
|
||||
|
||||
if delta_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in streaming tool call extraction")
|
||||
return None
|
||||
Reference in New Issue
Block a user