[Refactor] Consolidate required/named tool_choice streaming into DelegatingParser (#41876)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
This commit is contained in:
Flora Feng
2026-05-07 12:50:59 -04:00
committed by GitHub
parent 9d6500b89d
commit 8eb401134e
3 changed files with 41 additions and 195 deletions
+15 -22
View File
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from copy import deepcopy
from unittest.mock import MagicMock
import pytest
import regex as re
@@ -11,7 +10,7 @@ from pydantic import TypeAdapter
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.tool_parsers.streaming import extract_required_tool_call_streaming
from vllm.tool_parsers.utils import get_json_schema_from_tools
pytestmark = pytest.mark.cpu_test
@@ -281,8 +280,6 @@ def test_structured_outputs_json_without_parameters(
@pytest.mark.parametrize("empty_params", [False, True])
@pytest.mark.parametrize("delta_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
def test_streaming_output_valid(output, empty_params, delta_len):
self = MagicMock()
output = deepcopy(output)
if empty_params:
output = [{"name": o["name"], "parameters": {}} for o in output]
@@ -295,14 +292,13 @@ def test_streaming_output_valid(output, empty_params, delta_len):
delta_text = output_json[i : i + delta_len]
current_text = previous_text + delta_text
delta_message, function_name_returned = (
OpenAIServingChat.extract_tool_call_required_streaming(
self,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned,
)
delta_message, function_name_returned = extract_required_tool_call_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned,
tool_call_idx=None,
tool_call_id_type="random",
)
if delta_message:
@@ -332,8 +328,6 @@ def test_streaming_output_valid(output, empty_params, delta_len):
def test_streaming_output_valid_with_trailing_extra_data():
self = MagicMock()
output = [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
output_json = json.dumps(output) + "\nDONE"
@@ -345,14 +339,13 @@ def test_streaming_output_valid_with_trailing_extra_data():
delta_text = output_json[i : i + delta_len]
current_text = previous_text + delta_text
delta_message, function_name_returned = (
OpenAIServingChat.extract_tool_call_required_streaming(
self,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned,
)
delta_message, function_name_returned = extract_required_tool_call_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned,
tool_call_idx=None,
tool_call_id_type="random",
)
if delta_message:
@@ -70,10 +70,6 @@ from vllm.reasoning import ReasoningParser
from vllm.renderers import ChatParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.streaming import (
extract_named_tool_call_streaming,
extract_required_tool_call_streaming,
)
from vllm.utils.collection_utils import as_list
from vllm.utils.mistral import is_mistral_tokenizer, is_mistral_tool_parser
@@ -389,23 +385,6 @@ class OpenAIServingChat(OpenAIServing):
return self.response_role
return request.messages[-1]["role"]
def extract_tool_call_required_streaming(
self,
previous_text: str,
current_text: str | None,
delta_text: str,
function_name_returned: bool,
tool_call_idx: int | None = None,
) -> tuple[DeltaMessage | None, bool]:
return extract_required_tool_call_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned,
tool_call_idx=tool_call_idx,
tool_call_id_type=self.tool_call_id_type,
)
async def chat_completion_stream_generator(
self,
request: ChatCompletionRequest,
@@ -448,22 +427,7 @@ class OpenAIServingChat(OpenAIServing):
and self._should_stream_with_auto_tool_parsing(request)
)
# Determine whether required/named tool_choice should fall back to
# the auto tool_parser path instead of the standard JSON-based parsing.
# This happens when the parser declares supports_required_and_named=False
# (e.g. GLM models that output XML instead of JSON).
tool_choice_uses_parser = (
self.tool_parser is not None
and not self.tool_parser.supports_required_and_named
and request.tools
and (
request.tool_choice == "required"
or isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
)
)
all_previous_token_ids: list[list[int]] | None
function_name_returned = [False] * num_choices
if self.tool_call_id_type == "kimi_k2":
history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
else:
@@ -477,10 +441,10 @@ class OpenAIServingChat(OpenAIServing):
if (
is_mistral_grammar_path
or tool_choice_auto
or tool_choice_uses_parser
or tool_choice_function_name
or request.tool_choice == "required"
or reasoning_parser
):
# These are only required in "auto" tool choice case
all_previous_token_ids = [[] for _ in range(num_choices)]
reasoning_end_arr = [False] * num_choices
prompt_is_reasoning_end_arr: list[bool | None] = [None] * num_choices
@@ -501,6 +465,10 @@ class OpenAIServingChat(OpenAIServing):
)
for _ in range(num_choices)
]
for p in parsers:
if p is not None:
p._stream_state.tool_call_id_type = self.tool_call_id_type
p._stream_state.history_tool_call_cnt = history_tool_call_cnt
else:
parsers = [None] * num_choices
except Exception as e:
@@ -677,7 +645,8 @@ class OpenAIServingChat(OpenAIServing):
if (
is_mistral_grammar_path
or tool_choice_auto
or tool_choice_uses_parser
or tool_choice_function_name
or request.tool_choice == "required"
or reasoning_parser
):
assert previous_texts is not None
@@ -731,135 +700,6 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids = result.current_token_ids
if result.tools_called:
tools_streamed[i] = True
# handle streaming deltas for tools with named tool_choice
# Skip when tool_choice_uses_parser so it falls through
# to the auto tool_parser branches below.
elif tool_choice_function_name and not tool_choice_uses_parser:
# When encountering think end id in prompt_token_ids
# i.e {"enable_thinking": False},
# check BEFORE calling the parser to avoid a spurious
# reasoning delta on the first chunk.
if (
reasoning_parser
and not reasoning_end_arr[i]
and prompt_is_reasoning_end_arr[i]
):
reasoning_end_arr[i] = True
if (
reasoning_parser
and not reasoning_end_arr[i]
and not reasoning_parser.is_reasoning_end(
previous_token_ids
)
):
assert reasoning_parser is not None
delta_message = (
reasoning_parser.extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
output.token_ids,
)
)
# When encountering think end id in delta_token_ids,
# set reasoning status to end.
# Only keep 'content', remove 'reasoning'.
if reasoning_parser.is_reasoning_end(
as_list(output.token_ids)
):
reasoning_end_arr[i] = True
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
else:
# Just to add remaining `content`
if reasoning_parser:
delta_text = previous_text + delta_text
current_text = ""
delta_message, function_name_returned[i] = (
extract_named_tool_call_streaming(
delta_text=delta_text,
function_name=tool_choice_function_name,
function_name_returned=function_name_returned[i],
tool_call_idx=history_tool_call_cnt,
tool_call_id_type=self.tool_call_id_type,
tokenizer=tokenizer,
tool_call_array_index=i,
)
)
if (
delta_message
and delta_message.tool_calls
and delta_message.tool_calls[0].id is not None
):
history_tool_call_cnt += 1
tools_streamed[i] = True
# Skip when tool_choice_uses_parser so it falls through
# to the auto tool_parser branches below.
elif (
request.tool_choice == "required"
and not tool_choice_uses_parser
):
assert previous_texts is not None
previous_text = previous_texts[i]
current_text = previous_text + delta_text
fn_name_returned = function_name_returned[i]
output_token_ids = as_list(output.token_ids)
if (
reasoning_parser is not None
and not reasoning_end_arr[i]
and prompt_is_reasoning_end_arr[i]
):
reasoning_end_arr[i] = True
if reasoning_parser and not reasoning_end_arr[i]:
delta_message = (
reasoning_parser.extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
output_token_ids,
)
)
if reasoning_parser.is_reasoning_end(output_token_ids):
reasoning_end_arr[i] = True
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
# reasoning ended
current_text = ""
else:
# either finished reasoning or no reasoning at all
content = current_text
delta_message, function_name_returned[i] = (
self.extract_tool_call_required_streaming(
previous_text=previous_text,
current_text=content,
delta_text=delta_text,
function_name_returned=fn_name_returned,
tool_call_idx=history_tool_call_cnt,
)
)
if (
delta_message
and delta_message.tool_calls
and delta_message.tool_calls[0].id is not None
):
history_tool_call_cnt += 1
tools_streamed[i] = True
elif parser is not None:
delta_message = parser.parse_delta(
@@ -878,7 +718,8 @@ class OpenAIServingChat(OpenAIServing):
if (
is_mistral_grammar_path
or tool_choice_auto
or tool_choice_uses_parser
or tool_choice_function_name
or request.tool_choice == "required"
or reasoning_parser
) and not self.use_harmony:
assert previous_texts is not None
+16 -4
View File
@@ -587,9 +587,15 @@ class DelegatingParser(Parser):
tool_call_id_type: str = "random",
function_name_returned: bool = False,
) -> tuple[DeltaMessage | None, bool]:
if request.tool_choice and isinstance(
request.tool_choice,
(ToolChoiceFunction, ChatCompletionNamedToolChoiceParam),
assert self._tool_parser is not None
supports_required_and_named = self._tool_parser.supports_required_and_named
if (
supports_required_and_named
and request.tool_choice
and isinstance(
request.tool_choice,
(ToolChoiceFunction, ChatCompletionNamedToolChoiceParam),
)
):
delta_message, function_name_returned = extract_named_tool_call_streaming(
delta_text=delta_text,
@@ -601,7 +607,7 @@ class DelegatingParser(Parser):
)
return delta_message, function_name_returned
if request.tool_choice == "required":
if supports_required_and_named and request.tool_choice == "required":
delta_message, function_name_returned = (
extract_required_tool_call_streaming(
previous_text=previous_text,
@@ -706,6 +712,12 @@ class DelegatingParser(Parser):
function_name_returned=state.function_name_returned,
)
)
if (
delta_message
and delta_message.tool_calls
and delta_message.tool_calls[0].id is not None
):
state.history_tool_call_cnt += 1
# No phase active: pass through as content
if (