mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Refactor] Move unstreamed tool-arg flush from serving layer to parser (#44017)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
This commit is contained in:
@@ -1935,8 +1935,10 @@ async def test_streaming_n_gt1_independent_tool_parsers():
|
||||
finished=True,
|
||||
)
|
||||
|
||||
# Collect tool-call deltas per choice from the SSE stream.
|
||||
# Collect tool-call deltas and finish_reasons per choice from the SSE
|
||||
# stream.
|
||||
tc_deltas_by_choice: dict[int, list[dict]] = {i: [] for i in range(num_choices)}
|
||||
finish_reasons_by_choice: dict[int, list[str]] = {i: [] for i in range(num_choices)}
|
||||
async for chunk_str in serving_chat.chat_completion_stream_generator(
|
||||
request=request,
|
||||
result_generator=result_generator(),
|
||||
@@ -1959,6 +1961,8 @@ async def test_streaming_n_gt1_independent_tool_parsers():
|
||||
if delta.get("tool_calls"):
|
||||
for tc in delta["tool_calls"]:
|
||||
tc_deltas_by_choice[idx].append(tc)
|
||||
if choice.get("finish_reason") is not None:
|
||||
finish_reasons_by_choice[idx].append(choice["finish_reason"])
|
||||
|
||||
# Both choices must independently produce the correct tool call.
|
||||
for choice_idx in range(num_choices):
|
||||
@@ -1984,141 +1988,11 @@ async def test_streaming_n_gt1_independent_tool_parsers():
|
||||
f"Choice {choice_idx}: expected {{'city': 'Tokyo'}}, got {parsed_args}"
|
||||
)
|
||||
|
||||
|
||||
class TestCreateRemainingArgsDelta:
|
||||
"""Tests for _create_remaining_args_delta helper function.
|
||||
|
||||
This helper is used when streaming tool calls to preserve id/type/name
|
||||
fields in the finish chunk, which would otherwise be lost.
|
||||
"""
|
||||
|
||||
def test_preserves_id_type_name(self):
|
||||
"""Test that id, type, and name are preserved from original delta."""
|
||||
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
reasons = finish_reasons_by_choice[choice_idx]
|
||||
assert len(reasons) == 1, (
|
||||
f"Choice {choice_idx}: expected exactly 1 finish_reason, got {reasons}"
|
||||
)
|
||||
|
||||
original_delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=0,
|
||||
id="call_abc123",
|
||||
type="function",
|
||||
function=DeltaFunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"location": "Paris"}',
|
||||
),
|
||||
)
|
||||
]
|
||||
assert reasons[0] == "tool_calls", (
|
||||
f"Choice {choice_idx}: expected finish_reason='tool_calls', "
|
||||
f"got '{reasons[0]}'"
|
||||
)
|
||||
|
||||
result = OpenAIServingChat._create_remaining_args_delta(
|
||||
original_delta, '", "unit": "celsius"}', 0
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tc = result.tool_calls[0]
|
||||
assert tc.index == 0
|
||||
assert tc.id == "call_abc123"
|
||||
assert tc.type == "function"
|
||||
assert tc.function.name == "get_weather"
|
||||
assert tc.function.arguments == '", "unit": "celsius"}'
|
||||
|
||||
def test_matches_by_index(self):
    """Check that the tool call is selected by its ``index`` field.

    Two tool calls are supplied (index 0 and index 1); asking for index 1
    must yield a single tool call carrying the id and name of the second.
    """
    from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
    from vllm.entrypoints.openai.engine.protocol import (
        DeltaFunctionCall,
        DeltaMessage,
        DeltaToolCall,
    )

    # Build the two candidate tool calls separately for readability.
    first_call = DeltaToolCall(
        index=0,
        id="call_first",
        type="function",
        function=DeltaFunctionCall(name="func_a", arguments="{}"),
    )
    second_call = DeltaToolCall(
        index=1,
        id="call_second",
        type="function",
        function=DeltaFunctionCall(name="func_b", arguments="{}"),
    )
    source_delta = DeltaMessage(tool_calls=[first_call, second_call])

    flushed = OpenAIServingChat._create_remaining_args_delta(
        source_delta, '{"extra": true}', 1
    )

    assert len(flushed.tool_calls) == 1
    only_call = flushed.tool_calls[0]
    assert only_call.index == 1
    assert only_call.id == "call_second"
    assert only_call.function.name == "func_b"
|
||||
|
||||
def test_no_matching_tool_call(self):
    """Check the fallback when no tool call carries the requested index.

    The only tool call has index 0 while index 5 is requested, so the
    result keeps the requested index and the remaining arguments, with
    id, type and name left as ``None``.
    """
    from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
    from vllm.entrypoints.openai.engine.protocol import (
        DeltaFunctionCall,
        DeltaMessage,
        DeltaToolCall,
    )

    lone_call = DeltaToolCall(
        index=0,
        id="call_zero",
        type="function",
        function=DeltaFunctionCall(name="func", arguments="{}"),
    )
    source_delta = DeltaMessage(tool_calls=[lone_call])

    flushed = OpenAIServingChat._create_remaining_args_delta(
        source_delta, '{"arg": 1}', 5
    )

    assert len(flushed.tool_calls) == 1
    only_call = flushed.tool_calls[0]
    assert only_call.index == 5
    # Nothing matched, so no identifying fields can be carried over.
    assert only_call.id is None
    assert only_call.type is None
    assert only_call.function.name is None
    assert only_call.function.arguments == '{"arg": 1}'
|
||||
|
||||
def test_function_is_none(self):
    """Check the result when the matched tool call has ``function=None``.

    The id and type are still carried over; the function name is ``None``
    and the arguments are the remaining-arguments string.
    """
    from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
    from vllm.entrypoints.openai.engine.protocol import DeltaMessage, DeltaToolCall

    call_without_function = DeltaToolCall(
        index=0,
        id="call_nofunc",
        type="function",
        function=None,
    )
    source_delta = DeltaMessage(tool_calls=[call_without_function])

    flushed = OpenAIServingChat._create_remaining_args_delta(
        source_delta, '{"data": "value"}', 0
    )

    assert len(flushed.tool_calls) == 1
    only_call = flushed.tool_calls[0]
    assert only_call.index == 0
    assert only_call.id == "call_nofunc"
    assert only_call.type == "function"
    assert only_call.function.name is None
    assert only_call.function.arguments == '{"data": "value"}'
|
||||
|
||||
@@ -235,3 +235,86 @@ def test_parse_delta_reasoning_only_thinking_disabled(tokenizer, request_obj):
|
||||
assert "Hello" in content
|
||||
assert "assist" in content
|
||||
assert len(tool_calls) == 0
|
||||
|
||||
|
||||
def test_parse_delta_finished_no_flush_without_tool_call_delta(tokenizer, request_obj):
    """With finished=True but no tool-call delta from the final
    parse_delta call, the unstreamed arguments must not be flushed."""
    parser = make_parser(tokenizer, reasoning=False, tool=True)

    streamed_results = stream_text(
        parser, tokenizer, MODEL_OUTPUT, request_obj, prompt_token_ids=[]
    )
    _, _, tool_calls = collect_fields(streamed_results)
    assert len(tool_calls) > 0

    # Pretend the last five characters of the arguments were never streamed.
    tool_parser = parser._tool_parser
    already_streamed = tool_parser.streamed_args_for_tool[0]
    assert len(already_streamed) > 5
    tool_parser.streamed_args_for_tool[0] = already_streamed[
        : len(already_streamed) - 5
    ]

    # Stub out normal extraction so it cannot close the gap itself; with no
    # tool-call delta to merge into, the flush is skipped.
    def _no_tool_call_delta(*args, **kwargs):
        return None

    tool_parser.extract_tool_calls_streaming = _no_tool_call_delta

    final_delta = parser.parse_delta("", [], request_obj, finished=True)
    assert final_delta is None or final_delta.tool_calls is None
|
||||
|
||||
|
||||
def test_parse_delta_finished_no_extra_args_when_fully_streamed(tokenizer, request_obj):
    """Once every argument has been streamed, a finished=True call must
    not emit additional or duplicated arguments."""
    parser = make_parser(tokenizer, reasoning=False, tool=True)
    streamed_results = stream_text(
        parser, tokenizer, MODEL_OUTPUT, request_obj, prompt_token_ids=[]
    )
    _, _, tool_calls = collect_fields(streamed_results)

    assert len(tool_calls) > 0
    assert tool_calls[0].function.name == "get_weather"

    # Re-assemble the argument string from the non-empty fragments.
    fragments = [
        call.function.arguments for call in tool_calls if call.function.arguments
    ]
    assert json.loads("".join(fragments)) == {"city": "Dallas"}

    # The stream is complete, so the final call has nothing left to add.
    final_delta = parser.parse_delta("", [], request_obj, finished=True)
    assert final_delta is None or final_delta.tool_calls is None
|
||||
|
||||
|
||||
def test_parse_delta_finished_appends_remaining_args(tokenizer, request_obj):
    """When finished=True and the tool parser has unstreamed args,
    parse_delta appends the remaining arguments to the tool-call delta.

    The model output is fed token by token. On the first step after a delta
    that carried tool-call arguments, ``get_remaining_unstreamed_args`` is
    stubbed to return a fixed remainder and ``parse_delta`` is called with
    ``finished=True``; the concatenated arguments must end with it.
    """
    parser = make_parser(tokenizer, reasoning=False, tool=True)
    token_ids = tokenizer.encode(MODEL_OUTPUT, add_special_tokens=False)

    remainder = ',"unit":"celsius"}'
    prompt_ids: list[int] | None = []
    results: list[DeltaMessage | None] = []
    for tid in token_ids:
        prev = results[-1] if results else None
        # Coerce to a real bool: the bare and-chain can evaluate to None or
        # a list, and this value is forwarded as the ``finished`` flag.
        prev_had_args = bool(
            prev
            and prev.tool_calls
            and any(tc.function and tc.function.arguments for tc in prev.tool_calls)
        )

        if prev_had_args:
            # Simulate a parser that still holds parsed-but-unstreamed args.
            parser._tool_parser.get_remaining_unstreamed_args = lambda: remainder

        result = parser.parse_delta(
            tokenizer.decode([tid]),
            [tid],
            request_obj,
            prompt_token_ids=prompt_ids,
            finished=prev_had_args,
        )
        # The prompt token ids are only supplied on the first call.
        prompt_ids = None
        results.append(result)

        if prev_had_args:
            # The finishing call has been made; stop feeding tokens.
            break

    _, _, tool_calls = collect_fields(results)
    tool_args = "".join(
        tc.function.arguments for tc in tool_calls if tc.function.arguments
    )
    assert tool_args.endswith(remainder)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
@@ -40,9 +39,7 @@ from vllm.entrypoints.openai.chat_completion.stream_harmony import (
|
||||
extract_harmony_streaming_delta,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ErrorResponse,
|
||||
FunctionCall,
|
||||
PromptTokenUsageInfo,
|
||||
@@ -65,7 +62,7 @@ from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||
from vllm.inputs import EngineInput
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.parser import ParserManager
|
||||
from vllm.parser.abstract_parser import Parser
|
||||
from vllm.reasoning import ReasoningParser
|
||||
@@ -715,6 +712,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
delta_token_ids=as_list(output.token_ids),
|
||||
request=request,
|
||||
prompt_token_ids=res.prompt_token_ids,
|
||||
finished=output.finish_reason is not None,
|
||||
)
|
||||
if delta_message and delta_message.tool_calls:
|
||||
tools_streamed[i] = True
|
||||
@@ -805,81 +803,13 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# finish_reason='error' indicates a retryable error
|
||||
self._raise_if_error(output.finish_reason, request_id)
|
||||
|
||||
# check to make sure we haven't "forgotten" to stream
|
||||
# any tokens that were generated but previously
|
||||
# matched by partial json parsing
|
||||
# only happens if we are NOT using structured outputs
|
||||
index = 0
|
||||
auto_tools_called = False
|
||||
if tool_parser:
|
||||
auto_tools_called = len(tool_parser.prev_tool_call_arr) > 0
|
||||
index = (
|
||||
len(tool_parser.prev_tool_call_arr) - 1
|
||||
if auto_tools_called
|
||||
else 0
|
||||
)
|
||||
should_check = (
|
||||
self._should_check_for_unstreamed_tool_arg_tokens(
|
||||
delta_message, output
|
||||
)
|
||||
)
|
||||
# only check if there are any tool calls
|
||||
# detected by partial parsing
|
||||
if should_check and tool_parser and auto_tools_called:
|
||||
latest_delta_len = 0
|
||||
if (
|
||||
isinstance(
|
||||
delta_message.tool_calls[0].function,
|
||||
DeltaFunctionCall,
|
||||
)
|
||||
) and isinstance(
|
||||
delta_message.tool_calls[0].function.arguments, str
|
||||
):
|
||||
latest_delta_len = len(
|
||||
delta_message.tool_calls[0].function.arguments
|
||||
)
|
||||
|
||||
# get the expected call based on partial JSON
|
||||
# parsing which "autocompletes" the JSON.
|
||||
# Tool parsers (e.g. Qwen3Coder) store
|
||||
# arguments as a JSON string in
|
||||
# prev_tool_call_arr. Calling json.dumps()
|
||||
# on an already-serialized string would
|
||||
# double-serialize it (e.g. '{"k":1}' becomes
|
||||
# '"{\\"k\\":1}"'), which then causes the
|
||||
# replace() below to fail and append the
|
||||
# entire double-serialized string as a
|
||||
# spurious final delta.
|
||||
args = tool_parser.prev_tool_call_arr[index].get(
|
||||
"arguments", {}
|
||||
)
|
||||
if isinstance(args, str):
|
||||
expected_call = args
|
||||
else:
|
||||
expected_call = json.dumps(args, ensure_ascii=False)
|
||||
|
||||
# get what we've streamed so far for arguments
|
||||
# for the current tool
|
||||
actual_call = tool_parser.streamed_args_for_tool[index]
|
||||
if latest_delta_len > 0:
|
||||
actual_call = actual_call[:-latest_delta_len]
|
||||
|
||||
# check to see if there's anything left to stream
|
||||
remaining_call = expected_call.replace(actual_call, "", 1)
|
||||
# set that as a delta message
|
||||
delta_message = self._create_remaining_args_delta(
|
||||
delta_message, remaining_call, index
|
||||
)
|
||||
|
||||
# Send the finish response for each request.n only once
|
||||
# In OpenAI's API, when a tool is called, the
|
||||
# finish_reason is:
|
||||
# "tool_calls" for "auto" or "required" tool calls,
|
||||
# and "stop" for named tool calls.
|
||||
if (
|
||||
auto_tools_called
|
||||
or (tools_streamed[i] and not tool_choice_function_name)
|
||||
or (self.use_harmony and harmony_tools_streamed[i])
|
||||
if (tools_streamed[i] and not tool_choice_function_name) or (
|
||||
self.use_harmony and harmony_tools_streamed[i]
|
||||
):
|
||||
finish_reason_ = "tool_calls"
|
||||
else:
|
||||
@@ -1535,56 +1465,3 @@ class OpenAIServingChat(OpenAIServing):
|
||||
and self.enable_auto_tools
|
||||
and request.tool_choice in ["auto", None]
|
||||
)
|
||||
|
||||
def _should_check_for_unstreamed_tool_arg_tokens(
|
||||
self,
|
||||
delta_message: DeltaMessage | None,
|
||||
output: CompletionOutput,
|
||||
) -> bool:
|
||||
"""
|
||||
Check to see if we should check for unstreamed tool arguments tokens.
|
||||
This is only applicable when auto tool parsing is enabled, the delta
|
||||
is a tool call with arguments.
|
||||
"""
|
||||
|
||||
return bool(
|
||||
# if there is a delta message that includes tool calls which
|
||||
# include a function that has arguments
|
||||
output.finish_reason is not None
|
||||
and self.enable_auto_tools
|
||||
and self.tool_parser
|
||||
and delta_message
|
||||
and delta_message.tool_calls
|
||||
and delta_message.tool_calls[0]
|
||||
and delta_message.tool_calls[0].function
|
||||
and delta_message.tool_calls[0].function.arguments is not None
|
||||
)
|
||||
|
||||
@staticmethod
def _create_remaining_args_delta(
    delta_message: DeltaMessage,
    remaining_call: str,
    index: int,
) -> DeltaMessage:
    """
    Build a delta message that carries the remaining tool arguments.

    The first tool call in ``delta_message`` whose ``index`` equals
    ``index`` supplies the id, type and function name. If none matches,
    those fields are None. The function arguments are ``remaining_call``.
    """
    # Locate the first tool call with the requested index, if any.
    matched_call = None
    for candidate in delta_message.tool_calls:
        if candidate.index == index:
            matched_call = candidate
            break

    if matched_call:
        call_id = matched_call.id
        call_type = matched_call.type
        matched_fn = matched_call.function
    else:
        call_id = None
        call_type = None
        matched_fn = None

    # The matched tool call may itself have no function attached.
    fn_name = matched_fn.name if matched_fn else None

    remaining_fn = DeltaFunctionCall(name=fn_name, arguments=remaining_call)
    remaining_tool_call = DeltaToolCall(
        index=index,
        id=call_id,
        type=call_type,
        function=remaining_fn,
    )
    return DeltaMessage(tool_calls=[remaining_tool_call])
|
||||
|
||||
@@ -320,6 +320,7 @@ class Parser:
|
||||
delta_token_ids: list[int],
|
||||
request: ChatCompletionRequest | ResponsesRequest,
|
||||
prompt_token_ids: list[int] | None = None,
|
||||
finished: bool = False,
|
||||
) -> DeltaMessage | None:
|
||||
"""Parse a single streaming delta, orchestrating reasoning then
|
||||
tool call extraction via internal stream state.
|
||||
@@ -656,12 +657,28 @@ class DelegatingParser(Parser):
|
||||
return False
|
||||
return state.reasoning_ended
|
||||
|
||||
def _append_unstreamed_tool_args(
|
||||
self,
|
||||
delta_message: DeltaMessage | None,
|
||||
) -> None:
|
||||
"""Append parsed-but-unstreamed tool-call arguments to *delta_message*."""
|
||||
if (
|
||||
self._tool_parser is not None
|
||||
and delta_message
|
||||
and delta_message.tool_calls
|
||||
and (last_tc := delta_message.tool_calls[-1]).function
|
||||
):
|
||||
last_tc.function.arguments = (
|
||||
last_tc.function.arguments or ""
|
||||
) + self._tool_parser.get_remaining_unstreamed_args()
|
||||
|
||||
def parse_delta(
|
||||
self,
|
||||
delta_text: str,
|
||||
delta_token_ids: list[int],
|
||||
request: ChatCompletionRequest | ResponsesRequest,
|
||||
prompt_token_ids: list[int] | None = None,
|
||||
finished: bool = False,
|
||||
) -> DeltaMessage | None:
|
||||
state = self._stream_state
|
||||
|
||||
@@ -745,6 +762,10 @@ class DelegatingParser(Parser):
|
||||
|
||||
state.previous_text = current_text
|
||||
state.previous_token_ids = current_token_ids
|
||||
|
||||
if finished:
|
||||
self._append_unstreamed_tool_args(delta_message)
|
||||
|
||||
return delta_message
|
||||
|
||||
|
||||
|
||||
@@ -79,6 +79,25 @@ class ToolParser:
|
||||
else:
|
||||
self.tools = []
|
||||
|
||||
def get_remaining_unstreamed_args(self) -> str:
    """Return the part of the latest tool call's arguments not yet streamed.

    The expected argument string comes from the last entry of
    ``prev_tool_call_arr`` (used as-is when already a string, otherwise
    JSON-serialized). If the text recorded in ``streamed_args_for_tool``
    for that tool is a prefix of it, the remaining suffix is returned.
    An empty string is returned when no tool call exists or the streamed
    text is not a prefix of the expected arguments.
    """
    parsed_calls = self.prev_tool_call_arr
    if not parsed_calls:
        return ""

    last = len(parsed_calls) - 1
    raw_args = parsed_calls[last].get("arguments", {})
    # Arguments that are already a string must not be serialized again.
    full_args = (
        raw_args
        if isinstance(raw_args, str)
        else json.dumps(raw_args, ensure_ascii=False)
    )

    # Treat a tool with no streamed-args entry as having streamed nothing.
    streamed = self.streamed_args_for_tool
    already_sent = streamed[last] if last < len(streamed) else ""

    if not full_args.startswith(already_sent):
        return ""
    return full_args[len(already_sent) :]
|
||||
|
||||
@cached_property
|
||||
def vocab(self) -> dict[str, int]:
|
||||
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
|
||||
|
||||
Reference in New Issue
Block a user