mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Refactor] Consolidate required/named tool_choice streaming into DelegatingParser (#41876)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import regex as re
|
||||
@@ -11,7 +10,7 @@ from pydantic import TypeAdapter
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionToolsParam,
|
||||
)
|
||||
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
|
||||
from vllm.tool_parsers.streaming import extract_required_tool_call_streaming
|
||||
from vllm.tool_parsers.utils import get_json_schema_from_tools
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
@@ -281,8 +280,6 @@ def test_structured_outputs_json_without_parameters(
|
||||
@pytest.mark.parametrize("empty_params", [False, True])
|
||||
@pytest.mark.parametrize("delta_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
def test_streaming_output_valid(output, empty_params, delta_len):
|
||||
self = MagicMock()
|
||||
|
||||
output = deepcopy(output)
|
||||
if empty_params:
|
||||
output = [{"name": o["name"], "parameters": {}} for o in output]
|
||||
@@ -295,14 +292,13 @@ def test_streaming_output_valid(output, empty_params, delta_len):
|
||||
delta_text = output_json[i : i + delta_len]
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message, function_name_returned = (
|
||||
OpenAIServingChat.extract_tool_call_required_streaming(
|
||||
self,
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
function_name_returned=function_name_returned,
|
||||
)
|
||||
delta_message, function_name_returned = extract_required_tool_call_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
function_name_returned=function_name_returned,
|
||||
tool_call_idx=None,
|
||||
tool_call_id_type="random",
|
||||
)
|
||||
|
||||
if delta_message:
|
||||
@@ -332,8 +328,6 @@ def test_streaming_output_valid(output, empty_params, delta_len):
|
||||
|
||||
|
||||
def test_streaming_output_valid_with_trailing_extra_data():
|
||||
self = MagicMock()
|
||||
|
||||
output = [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
|
||||
output_json = json.dumps(output) + "\nDONE"
|
||||
|
||||
@@ -345,14 +339,13 @@ def test_streaming_output_valid_with_trailing_extra_data():
|
||||
delta_text = output_json[i : i + delta_len]
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message, function_name_returned = (
|
||||
OpenAIServingChat.extract_tool_call_required_streaming(
|
||||
self,
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
function_name_returned=function_name_returned,
|
||||
)
|
||||
delta_message, function_name_returned = extract_required_tool_call_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
function_name_returned=function_name_returned,
|
||||
tool_call_idx=None,
|
||||
tool_call_id_type="random",
|
||||
)
|
||||
|
||||
if delta_message:
|
||||
|
||||
@@ -70,10 +70,6 @@ from vllm.reasoning import ReasoningParser
|
||||
from vllm.renderers import ChatParams
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.streaming import (
|
||||
extract_named_tool_call_streaming,
|
||||
extract_required_tool_call_streaming,
|
||||
)
|
||||
from vllm.utils.collection_utils import as_list
|
||||
from vllm.utils.mistral import is_mistral_tokenizer, is_mistral_tool_parser
|
||||
|
||||
@@ -389,23 +385,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
return self.response_role
|
||||
return request.messages[-1]["role"]
|
||||
|
||||
def extract_tool_call_required_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str | None,
|
||||
delta_text: str,
|
||||
function_name_returned: bool,
|
||||
tool_call_idx: int | None = None,
|
||||
) -> tuple[DeltaMessage | None, bool]:
|
||||
return extract_required_tool_call_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
function_name_returned=function_name_returned,
|
||||
tool_call_idx=tool_call_idx,
|
||||
tool_call_id_type=self.tool_call_id_type,
|
||||
)
|
||||
|
||||
async def chat_completion_stream_generator(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
@@ -448,22 +427,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
and self._should_stream_with_auto_tool_parsing(request)
|
||||
)
|
||||
|
||||
# Determine whether required/named tool_choice should fall back to
|
||||
# the auto tool_parser path instead of the standard JSON-based parsing.
|
||||
# This happens when the parser declares supports_required_and_named=False
|
||||
# (e.g. GLM models that output XML instead of JSON).
|
||||
tool_choice_uses_parser = (
|
||||
self.tool_parser is not None
|
||||
and not self.tool_parser.supports_required_and_named
|
||||
and request.tools
|
||||
and (
|
||||
request.tool_choice == "required"
|
||||
or isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
|
||||
)
|
||||
)
|
||||
|
||||
all_previous_token_ids: list[list[int]] | None
|
||||
function_name_returned = [False] * num_choices
|
||||
if self.tool_call_id_type == "kimi_k2":
|
||||
history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
|
||||
else:
|
||||
@@ -477,10 +441,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
if (
|
||||
is_mistral_grammar_path
|
||||
or tool_choice_auto
|
||||
or tool_choice_uses_parser
|
||||
or tool_choice_function_name
|
||||
or request.tool_choice == "required"
|
||||
or reasoning_parser
|
||||
):
|
||||
# These are only required in "auto" tool choice case
|
||||
all_previous_token_ids = [[] for _ in range(num_choices)]
|
||||
reasoning_end_arr = [False] * num_choices
|
||||
prompt_is_reasoning_end_arr: list[bool | None] = [None] * num_choices
|
||||
@@ -501,6 +465,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
)
|
||||
for _ in range(num_choices)
|
||||
]
|
||||
for p in parsers:
|
||||
if p is not None:
|
||||
p._stream_state.tool_call_id_type = self.tool_call_id_type
|
||||
p._stream_state.history_tool_call_cnt = history_tool_call_cnt
|
||||
else:
|
||||
parsers = [None] * num_choices
|
||||
except Exception as e:
|
||||
@@ -677,7 +645,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
if (
|
||||
is_mistral_grammar_path
|
||||
or tool_choice_auto
|
||||
or tool_choice_uses_parser
|
||||
or tool_choice_function_name
|
||||
or request.tool_choice == "required"
|
||||
or reasoning_parser
|
||||
):
|
||||
assert previous_texts is not None
|
||||
@@ -731,135 +700,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
current_token_ids = result.current_token_ids
|
||||
if result.tools_called:
|
||||
tools_streamed[i] = True
|
||||
# handle streaming deltas for tools with named tool_choice
|
||||
# Skip when tool_choice_uses_parser so it falls through
|
||||
# to the auto tool_parser branches below.
|
||||
elif tool_choice_function_name and not tool_choice_uses_parser:
|
||||
# When encountering think end id in prompt_token_ids
|
||||
# i.e {"enable_thinking": False},
|
||||
# check BEFORE calling the parser to avoid a spurious
|
||||
# reasoning delta on the first chunk.
|
||||
if (
|
||||
reasoning_parser
|
||||
and not reasoning_end_arr[i]
|
||||
and prompt_is_reasoning_end_arr[i]
|
||||
):
|
||||
reasoning_end_arr[i] = True
|
||||
|
||||
if (
|
||||
reasoning_parser
|
||||
and not reasoning_end_arr[i]
|
||||
and not reasoning_parser.is_reasoning_end(
|
||||
previous_token_ids
|
||||
)
|
||||
):
|
||||
assert reasoning_parser is not None
|
||||
delta_message = (
|
||||
reasoning_parser.extract_reasoning_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
output.token_ids,
|
||||
)
|
||||
)
|
||||
# When encountering think end id in delta_token_ids,
|
||||
# set reasoning status to end.
|
||||
# Only keep 'content', remove 'reasoning'.
|
||||
if reasoning_parser.is_reasoning_end(
|
||||
as_list(output.token_ids)
|
||||
):
|
||||
reasoning_end_arr[i] = True
|
||||
if delta_message and delta_message.content:
|
||||
current_text = delta_message.content
|
||||
delta_message.content = None
|
||||
else:
|
||||
current_text = ""
|
||||
else:
|
||||
# Just to add remaining `content`
|
||||
if reasoning_parser:
|
||||
delta_text = previous_text + delta_text
|
||||
current_text = ""
|
||||
|
||||
delta_message, function_name_returned[i] = (
|
||||
extract_named_tool_call_streaming(
|
||||
delta_text=delta_text,
|
||||
function_name=tool_choice_function_name,
|
||||
function_name_returned=function_name_returned[i],
|
||||
tool_call_idx=history_tool_call_cnt,
|
||||
tool_call_id_type=self.tool_call_id_type,
|
||||
tokenizer=tokenizer,
|
||||
tool_call_array_index=i,
|
||||
)
|
||||
)
|
||||
if (
|
||||
delta_message
|
||||
and delta_message.tool_calls
|
||||
and delta_message.tool_calls[0].id is not None
|
||||
):
|
||||
history_tool_call_cnt += 1
|
||||
tools_streamed[i] = True
|
||||
|
||||
# Skip when tool_choice_uses_parser so it falls through
|
||||
# to the auto tool_parser branches below.
|
||||
elif (
|
||||
request.tool_choice == "required"
|
||||
and not tool_choice_uses_parser
|
||||
):
|
||||
assert previous_texts is not None
|
||||
previous_text = previous_texts[i]
|
||||
current_text = previous_text + delta_text
|
||||
fn_name_returned = function_name_returned[i]
|
||||
output_token_ids = as_list(output.token_ids)
|
||||
|
||||
if (
|
||||
reasoning_parser is not None
|
||||
and not reasoning_end_arr[i]
|
||||
and prompt_is_reasoning_end_arr[i]
|
||||
):
|
||||
reasoning_end_arr[i] = True
|
||||
|
||||
if reasoning_parser and not reasoning_end_arr[i]:
|
||||
delta_message = (
|
||||
reasoning_parser.extract_reasoning_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
output_token_ids,
|
||||
)
|
||||
)
|
||||
if reasoning_parser.is_reasoning_end(output_token_ids):
|
||||
reasoning_end_arr[i] = True
|
||||
if delta_message and delta_message.content:
|
||||
current_text = delta_message.content
|
||||
delta_message.content = None
|
||||
else:
|
||||
# reasoning ended
|
||||
current_text = ""
|
||||
|
||||
else:
|
||||
# either finished reasoning or no reasoning at all
|
||||
content = current_text
|
||||
|
||||
delta_message, function_name_returned[i] = (
|
||||
self.extract_tool_call_required_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=content,
|
||||
delta_text=delta_text,
|
||||
function_name_returned=fn_name_returned,
|
||||
tool_call_idx=history_tool_call_cnt,
|
||||
)
|
||||
)
|
||||
if (
|
||||
delta_message
|
||||
and delta_message.tool_calls
|
||||
and delta_message.tool_calls[0].id is not None
|
||||
):
|
||||
history_tool_call_cnt += 1
|
||||
tools_streamed[i] = True
|
||||
|
||||
elif parser is not None:
|
||||
delta_message = parser.parse_delta(
|
||||
@@ -878,7 +718,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
if (
|
||||
is_mistral_grammar_path
|
||||
or tool_choice_auto
|
||||
or tool_choice_uses_parser
|
||||
or tool_choice_function_name
|
||||
or request.tool_choice == "required"
|
||||
or reasoning_parser
|
||||
) and not self.use_harmony:
|
||||
assert previous_texts is not None
|
||||
|
||||
@@ -587,9 +587,15 @@ class DelegatingParser(Parser):
|
||||
tool_call_id_type: str = "random",
|
||||
function_name_returned: bool = False,
|
||||
) -> tuple[DeltaMessage | None, bool]:
|
||||
if request.tool_choice and isinstance(
|
||||
request.tool_choice,
|
||||
(ToolChoiceFunction, ChatCompletionNamedToolChoiceParam),
|
||||
assert self._tool_parser is not None
|
||||
supports_required_and_named = self._tool_parser.supports_required_and_named
|
||||
if (
|
||||
supports_required_and_named
|
||||
and request.tool_choice
|
||||
and isinstance(
|
||||
request.tool_choice,
|
||||
(ToolChoiceFunction, ChatCompletionNamedToolChoiceParam),
|
||||
)
|
||||
):
|
||||
delta_message, function_name_returned = extract_named_tool_call_streaming(
|
||||
delta_text=delta_text,
|
||||
@@ -601,7 +607,7 @@ class DelegatingParser(Parser):
|
||||
)
|
||||
return delta_message, function_name_returned
|
||||
|
||||
if request.tool_choice == "required":
|
||||
if supports_required_and_named and request.tool_choice == "required":
|
||||
delta_message, function_name_returned = (
|
||||
extract_required_tool_call_streaming(
|
||||
previous_text=previous_text,
|
||||
@@ -706,6 +712,12 @@ class DelegatingParser(Parser):
|
||||
function_name_returned=state.function_name_returned,
|
||||
)
|
||||
)
|
||||
if (
|
||||
delta_message
|
||||
and delta_message.tool_calls
|
||||
and delta_message.tool_calls[0].id is not None
|
||||
):
|
||||
state.history_tool_call_cnt += 1
|
||||
|
||||
# No phase active: pass through as content
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user