[Refactor] Move unstreamed tool-arg flush from serving layer to parser (#44017)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
This commit is contained in:
Flora Feng
2026-06-01 22:37:43 -04:00
committed by GitHub
parent 816cc73a9b
commit 9affc17a05
5 changed files with 138 additions and 264 deletions
@@ -1935,8 +1935,10 @@ async def test_streaming_n_gt1_independent_tool_parsers():
finished=True,
)
# Collect tool-call deltas per choice from the SSE stream.
# Collect tool-call deltas and finish_reasons per choice from the SSE
# stream.
tc_deltas_by_choice: dict[int, list[dict]] = {i: [] for i in range(num_choices)}
finish_reasons_by_choice: dict[int, list[str]] = {i: [] for i in range(num_choices)}
async for chunk_str in serving_chat.chat_completion_stream_generator(
request=request,
result_generator=result_generator(),
@@ -1959,6 +1961,8 @@ async def test_streaming_n_gt1_independent_tool_parsers():
if delta.get("tool_calls"):
for tc in delta["tool_calls"]:
tc_deltas_by_choice[idx].append(tc)
if choice.get("finish_reason") is not None:
finish_reasons_by_choice[idx].append(choice["finish_reason"])
# Both choices must independently produce the correct tool call.
for choice_idx in range(num_choices):
@@ -1984,141 +1988,11 @@ async def test_streaming_n_gt1_independent_tool_parsers():
f"Choice {choice_idx}: expected {{'city': 'Tokyo'}}, got {parsed_args}"
)
class TestCreateRemainingArgsDelta:
"""Tests for _create_remaining_args_delta helper function.
This helper is used when streaming tool calls to preserve id/type/name
fields in the finish chunk, which would otherwise be lost.
"""
def test_preserves_id_type_name(self):
"""Test that id, type, and name are preserved from original delta."""
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
reasons = finish_reasons_by_choice[choice_idx]
assert len(reasons) == 1, (
f"Choice {choice_idx}: expected exactly 1 finish_reason, got {reasons}"
)
original_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id="call_abc123",
type="function",
function=DeltaFunctionCall(
name="get_weather",
arguments='{"location": "Paris"}',
),
)
]
assert reasons[0] == "tool_calls", (
f"Choice {choice_idx}: expected finish_reason='tool_calls', "
f"got '{reasons[0]}'"
)
result = OpenAIServingChat._create_remaining_args_delta(
original_delta, '", "unit": "celsius"}', 0
)
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.index == 0
assert tc.id == "call_abc123"
assert tc.type == "function"
assert tc.function.name == "get_weather"
assert tc.function.arguments == '", "unit": "celsius"}'
def test_matches_by_index(self):
"""Test that the correct tool call is matched by index."""
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
)
original_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id="call_first",
type="function",
function=DeltaFunctionCall(name="func_a", arguments="{}"),
),
DeltaToolCall(
index=1,
id="call_second",
type="function",
function=DeltaFunctionCall(name="func_b", arguments="{}"),
),
]
)
result = OpenAIServingChat._create_remaining_args_delta(
original_delta, '{"extra": true}', 1
)
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.index == 1
assert tc.id == "call_second"
assert tc.function.name == "func_b"
def test_no_matching_tool_call(self):
"""Test graceful handling when no matching tool call is found."""
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
)
original_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id="call_zero",
type="function",
function=DeltaFunctionCall(name="func", arguments="{}"),
)
]
)
result = OpenAIServingChat._create_remaining_args_delta(
original_delta, '{"arg": 1}', 5
)
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.index == 5
assert tc.id is None
assert tc.type is None
assert tc.function.name is None
assert tc.function.arguments == '{"arg": 1}'
def test_function_is_none(self):
"""Test handling when original tool call has no function."""
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import DeltaMessage, DeltaToolCall
original_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id="call_nofunc",
type="function",
function=None,
)
]
)
result = OpenAIServingChat._create_remaining_args_delta(
original_delta, '{"data": "value"}', 0
)
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.index == 0
assert tc.id == "call_nofunc"
assert tc.type == "function"
assert tc.function.name is None
assert tc.function.arguments == '{"data": "value"}'
+83
View File
@@ -235,3 +235,86 @@ def test_parse_delta_reasoning_only_thinking_disabled(tokenizer, request_obj):
assert "Hello" in content
assert "assist" in content
assert len(tool_calls) == 0
def test_parse_delta_finished_no_flush_without_tool_call_delta(tokenizer, request_obj):
"""When finished=True but the final parse_delta produces no
tool-call delta, unstreamed args are not flushed."""
parser = make_parser(tokenizer, reasoning=False, tool=True)
results = stream_text(
parser, tokenizer, MODEL_OUTPUT, request_obj, prompt_token_ids=[]
)
_, _, tool_calls = collect_fields(results)
assert len(tool_calls) > 0
streamed = parser._tool_parser.streamed_args_for_tool[0]
assert len(streamed) > 5
parser._tool_parser.streamed_args_for_tool[0] = streamed[:-5]
# Prevent normal extraction from catching the gap — without a
# tool-call delta to merge into, the flush is skipped.
parser._tool_parser.extract_tool_calls_streaming = lambda *a, **kw: None
flush_result = parser.parse_delta("", [], request_obj, finished=True)
assert flush_result is None or flush_result.tool_calls is None
def test_parse_delta_finished_no_extra_args_when_fully_streamed(tokenizer, request_obj):
"""When all args have been streamed, finished=True must not
produce extra or duplicate arguments."""
parser = make_parser(tokenizer, reasoning=False, tool=True)
results = stream_text(
parser, tokenizer, MODEL_OUTPUT, request_obj, prompt_token_ids=[]
)
_, _, tool_calls = collect_fields(results)
assert len(tool_calls) > 0
assert tool_calls[0].function.name == "get_weather"
tool_args = "".join(
tc.function.arguments for tc in tool_calls if tc.function.arguments
)
assert json.loads(tool_args) == {"city": "Dallas"}
flush_result = parser.parse_delta("", [], request_obj, finished=True)
assert flush_result is None or flush_result.tool_calls is None
def test_parse_delta_finished_appends_remaining_args(tokenizer, request_obj):
"""When finished=True and the tool parser has unstreamed args,
parse_delta appends the remaining arguments to the tool-call delta."""
parser = make_parser(tokenizer, reasoning=False, tool=True)
token_ids = tokenizer.encode(MODEL_OUTPUT, add_special_tokens=False)
remainder = ',"unit":"celsius"}'
prompt_ids: list[int] | None = []
results: list[DeltaMessage | None] = []
for i, tid in enumerate(token_ids):
prev = results[-1] if results else None
prev_had_args = (
prev
and prev.tool_calls
and any(tc.function and tc.function.arguments for tc in prev.tool_calls)
)
if prev_had_args:
parser._tool_parser.get_remaining_unstreamed_args = lambda: remainder
result = parser.parse_delta(
tokenizer.decode([tid]),
[tid],
request_obj,
prompt_token_ids=prompt_ids,
finished=prev_had_args,
)
prompt_ids = None
results.append(result)
if prev_had_args:
break
_, _, tool_calls = collect_fields(results)
tool_args = "".join(
tc.function.arguments for tc in tool_calls if tc.function.arguments
)
assert tool_args.endswith(remainder)
@@ -3,7 +3,6 @@
import asyncio
import io
import json
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
@@ -40,9 +39,7 @@ from vllm.entrypoints.openai.chat_completion.stream_harmony import (
extract_harmony_streaming_delta,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ErrorResponse,
FunctionCall,
PromptTokenUsageInfo,
@@ -65,7 +62,7 @@ from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs import EngineInput
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.outputs import RequestOutput
from vllm.parser import ParserManager
from vllm.parser.abstract_parser import Parser
from vllm.reasoning import ReasoningParser
@@ -715,6 +712,7 @@ class OpenAIServingChat(OpenAIServing):
delta_token_ids=as_list(output.token_ids),
request=request,
prompt_token_ids=res.prompt_token_ids,
finished=output.finish_reason is not None,
)
if delta_message and delta_message.tool_calls:
tools_streamed[i] = True
@@ -805,81 +803,13 @@ class OpenAIServingChat(OpenAIServing):
# finish_reason='error' indicates a retryable error
self._raise_if_error(output.finish_reason, request_id)
# check to make sure we haven't "forgotten" to stream
# any tokens that were generated but previously
# matched by partial json parsing
# only happens if we are NOT using structured outputs
index = 0
auto_tools_called = False
if tool_parser:
auto_tools_called = len(tool_parser.prev_tool_call_arr) > 0
index = (
len(tool_parser.prev_tool_call_arr) - 1
if auto_tools_called
else 0
)
should_check = (
self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output
)
)
# only check if there are any tool calls
# detected by partial parsing
if should_check and tool_parser and auto_tools_called:
latest_delta_len = 0
if (
isinstance(
delta_message.tool_calls[0].function,
DeltaFunctionCall,
)
) and isinstance(
delta_message.tool_calls[0].function.arguments, str
):
latest_delta_len = len(
delta_message.tool_calls[0].function.arguments
)
# get the expected call based on partial JSON
# parsing which "autocompletes" the JSON.
# Tool parsers (e.g. Qwen3Coder) store
# arguments as a JSON string in
# prev_tool_call_arr. Calling json.dumps()
# on an already-serialized string would
# double-serialize it (e.g. '{"k":1}' becomes
# '"{\\"k\\":1}"'), which then causes the
# replace() below to fail and append the
# entire double-serialized string as a
# spurious final delta.
args = tool_parser.prev_tool_call_arr[index].get(
"arguments", {}
)
if isinstance(args, str):
expected_call = args
else:
expected_call = json.dumps(args, ensure_ascii=False)
# get what we've streamed so far for arguments
# for the current tool
actual_call = tool_parser.streamed_args_for_tool[index]
if latest_delta_len > 0:
actual_call = actual_call[:-latest_delta_len]
# check to see if there's anything left to stream
remaining_call = expected_call.replace(actual_call, "", 1)
# set that as a delta message
delta_message = self._create_remaining_args_delta(
delta_message, remaining_call, index
)
# Send the finish response for each request.n only once
# In OpenAI's API, when a tool is called, the
# finish_reason is:
# "tool_calls" for "auto" or "required" tool calls,
# and "stop" for named tool calls.
if (
auto_tools_called
or (tools_streamed[i] and not tool_choice_function_name)
or (self.use_harmony and harmony_tools_streamed[i])
if (tools_streamed[i] and not tool_choice_function_name) or (
self.use_harmony and harmony_tools_streamed[i]
):
finish_reason_ = "tool_calls"
else:
@@ -1535,56 +1465,3 @@ class OpenAIServingChat(OpenAIServing):
and self.enable_auto_tools
and request.tool_choice in ["auto", None]
)
def _should_check_for_unstreamed_tool_arg_tokens(
self,
delta_message: DeltaMessage | None,
output: CompletionOutput,
) -> bool:
"""
Check to see if we should check for unstreamed tool arguments tokens.
This is only applicable when auto tool parsing is enabled, the delta
is a tool call with arguments.
"""
return bool(
# if there is a delta message that includes tool calls which
# include a function that has arguments
output.finish_reason is not None
and self.enable_auto_tools
and self.tool_parser
and delta_message
and delta_message.tool_calls
and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None
)
@staticmethod
def _create_remaining_args_delta(
delta_message: DeltaMessage,
remaining_call: str,
index: int,
) -> DeltaMessage:
"""
Create a delta message for remaining tool arguments, preserving
id/type/name from the original delta.
"""
original_tc = next(
(tc for tc in delta_message.tool_calls if tc.index == index),
None,
)
original_fn = original_tc.function if original_tc else None
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=index,
id=original_tc.id if original_tc else None,
type=original_tc.type if original_tc else None,
function=DeltaFunctionCall(
name=original_fn.name if original_fn else None,
arguments=remaining_call,
),
)
]
)
+21
View File
@@ -320,6 +320,7 @@ class Parser:
delta_token_ids: list[int],
request: ChatCompletionRequest | ResponsesRequest,
prompt_token_ids: list[int] | None = None,
finished: bool = False,
) -> DeltaMessage | None:
"""Parse a single streaming delta, orchestrating reasoning then
tool call extraction via internal stream state.
@@ -656,12 +657,28 @@ class DelegatingParser(Parser):
return False
return state.reasoning_ended
def _append_unstreamed_tool_args(
self,
delta_message: DeltaMessage | None,
) -> None:
"""Append parsed-but-unstreamed tool-call arguments to *delta_message*."""
if (
self._tool_parser is not None
and delta_message
and delta_message.tool_calls
and (last_tc := delta_message.tool_calls[-1]).function
):
last_tc.function.arguments = (
last_tc.function.arguments or ""
) + self._tool_parser.get_remaining_unstreamed_args()
def parse_delta(
self,
delta_text: str,
delta_token_ids: list[int],
request: ChatCompletionRequest | ResponsesRequest,
prompt_token_ids: list[int] | None = None,
finished: bool = False,
) -> DeltaMessage | None:
state = self._stream_state
@@ -745,6 +762,10 @@ class DelegatingParser(Parser):
state.previous_text = current_text
state.previous_token_ids = current_token_ids
if finished:
self._append_unstreamed_tool_args(delta_message)
return delta_message
+19
View File
@@ -79,6 +79,25 @@ class ToolParser:
else:
self.tools = []
def get_remaining_unstreamed_args(self) -> str:
"""Return tool call arguments parsed but not yet streamed."""
if not self.prev_tool_call_arr:
return ""
index = len(self.prev_tool_call_arr) - 1
args = self.prev_tool_call_arr[index].get("arguments", {})
if isinstance(args, str):
expected = args
else:
expected = json.dumps(args, ensure_ascii=False)
actual = (
self.streamed_args_for_tool[index]
if index < len(self.streamed_args_for_tool)
else ""
)
if expected.startswith(actual):
return expected[len(actual) :]
return ""
@cached_property
def vocab(self) -> dict[str, int]:
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab