Merge branch 'main' into wosouk/dsv4-attn-cleanup-2

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-06-02 03:32:33 +00:00
23 changed files with 834 additions and 496 deletions
+39
View File
@@ -0,0 +1,39 @@
#!/bin/bash
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
set -euo pipefail
REQUIREMENTS_FILE="${KV_CONNECTORS_REQUIREMENTS:-/vllm-workspace/requirements/kv_connectors.txt}"
uv pip install --system -r "${REQUIREMENTS_FILE}"
NIXL_METADATA=$(python3 - <<'PY'
import importlib.metadata as metadata
import torch
cuda_version = torch.version.cuda
if cuda_version is None:
raise SystemExit("torch.version.cuda is not set")
print(cuda_version.split(".", 1)[0], metadata.version("nixl"))
PY
)
read -r CUDA_MAJOR NIXL_VERSION <<<"${NIXL_METADATA}"
# nixl>=1.1.0 can install multiple CUDA wheel variants. Keep only the variant
# matching this CI image so nixl_ep_cpp links against the available libcudart.
uv pip uninstall --system nixl-cu12 nixl-cu13 2>/dev/null || true
uv pip install --system --no-deps "nixl-cu${CUDA_MAJOR}==${NIXL_VERSION}"
python3 - <<'PY'
import importlib.metadata as metadata
for package_name in ("nixl", "nixl-cu12", "nixl-cu13"):
try:
version = metadata.version(package_name)
except metadata.PackageNotFoundError:
version = "not installed"
print(f"{package_name}: {version}")
PY
+9 -9
View File
@@ -11,7 +11,7 @@ steps:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
- label: Distributed FlashInfer NixlConnector PD accuracy (4 GPUs)
key: distributed-flashinfer-nixlconnector-pd-accuracy-4-gpus
@@ -22,7 +22,7 @@ steps:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- FLASHINFER=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
- label: DP EP Distributed NixlConnector PD accuracy tests (4 GPUs)
@@ -34,7 +34,7 @@ steps:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- DP_EP=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
- label: CrossLayer KV layout Distributed NixlConnector PD accuracy tests (4 GPUs)
@@ -46,7 +46,7 @@ steps:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- CROSS_LAYERS_BLOCKS=True bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
- label: Hybrid SSM NixlConnector PD accuracy tests (4 GPUs)
@@ -58,7 +58,7 @@ steps:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- HYBRID_SSM=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
- label: MultiConnector (Nixl+Offloading) PD accuracy (2 GPUs)
@@ -73,7 +73,7 @@ steps:
- vllm/distributed/kv_transfer/kv_connector/v1/offloading/
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- bash v1/kv_connector/nixl_integration/run_multi_connector_accuracy_test.sh
- label: NixlConnector PD + Spec Decode acceptance (2 GPUs)
@@ -87,7 +87,7 @@ steps:
- vllm/v1/worker/kv_connector_model_runner_mixin.py
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- bash v1/kv_connector/nixl_integration/config_sweep_spec_decode_test.sh
- label: MultiConnector (Nixl+Offloading) PD edge cases (2 GPUs)
@@ -102,5 +102,5 @@ steps:
- vllm/distributed/kv_transfer/kv_connector/v1/offloading/
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash v1/kv_connector/nixl_integration/run_multi_connector_edge_case_test.sh
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- bash v1/kv_connector/nixl_integration/run_multi_connector_edge_case_test.sh
+1 -1
View File
@@ -86,7 +86,7 @@ steps:
- tests/v1/metrics
- tests/entrypoints/openai/correctness/test_lmeval.py
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
# split the test to avoid interference
- pytest -v -s -m 'not cpu_test' v1/core
+1 -1
View File
@@ -9,7 +9,7 @@ ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
ARG FA_BRANCH="0e60e394"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="v0.1.13"
ARG AITER_BRANCH="v0.1.13.post1"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ARG MORI_BRANCH="v1.1.0"
ARG MORI_REPO="https://github.com/ROCm/mori.git"
@@ -55,6 +55,15 @@ def test_dynamic_shapes_compilation(
evaluate_guards,
):
"""Test that all dynamic shapes types compile successfully"""
if shapes_type == DynamicShapesType.UNBACKED and not is_torch_equal_or_newer(
"2.11.0"
):
# NOTE[ROCm]: shape_id (used by Qwen2/Llama to relate input dims) only
# landed in torch 2.11, but the ROCm CI still runs torch 2.10.x. On
# older torch there's no way to express it, so unbacked shapes go
# data-dependent and compilation blows up -- nothing to test.
pytest.skip("unbacked dynamic shapes with shape_id require torch>=2.11")
if evaluate_guards and shapes_type == DynamicShapesType.UNBACKED:
pytest.skip("unbacked dynamic shapes do not add guards")
@@ -6,9 +6,13 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import ValidationError
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.chat_completion.protocol import (
BatchChatCompletionRequest,
ChatCompletionRequest,
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import GenerationError
from vllm.entrypoints.openai.models.protocol import BaseModelPath
@@ -444,3 +448,45 @@ def test_json_schema_response_format_missing_schema():
messages=[{"role": "user", "content": "hello"}],
response_format={"type": "json_schema"},
)
@pytest.mark.parametrize("format_value", [None, {}])
def test_structural_tag_response_format_invalid(format_value):
"""Malformed structural tags should be rejected during request validation."""
with pytest.raises(
ValidationError,
match="Invalid response_format structural_tag",
):
ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "hello"}],
response_format={"type": "structural_tag", "format": format_value},
)
@pytest.mark.parametrize("format_value", [None, {}])
def test_batch_structural_tag_response_format_invalid(format_value):
"""Batch chat should reject malformed structural tags at request parsing."""
with pytest.raises(
ValidationError,
match="Invalid response_format structural_tag",
):
BatchChatCompletionRequest(
model=MODEL_NAME,
messages=[[{"role": "user", "content": "hello"}]],
response_format={"type": "structural_tag", "format": format_value},
)
@pytest.mark.parametrize("structural_tag", ["not json", ""])
def test_structured_outputs_structural_tag_invalid(structural_tag):
"""Malformed direct structured_outputs structural tags should be rejected."""
with pytest.raises(
ValidationError,
match="Invalid structured_outputs structural_tag",
):
ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "hello"}],
structured_outputs={"structural_tag": structural_tag},
)
@@ -1935,8 +1935,10 @@ async def test_streaming_n_gt1_independent_tool_parsers():
finished=True,
)
# Collect tool-call deltas per choice from the SSE stream.
# Collect tool-call deltas and finish_reasons per choice from the SSE
# stream.
tc_deltas_by_choice: dict[int, list[dict]] = {i: [] for i in range(num_choices)}
finish_reasons_by_choice: dict[int, list[str]] = {i: [] for i in range(num_choices)}
async for chunk_str in serving_chat.chat_completion_stream_generator(
request=request,
result_generator=result_generator(),
@@ -1959,6 +1961,8 @@ async def test_streaming_n_gt1_independent_tool_parsers():
if delta.get("tool_calls"):
for tc in delta["tool_calls"]:
tc_deltas_by_choice[idx].append(tc)
if choice.get("finish_reason") is not None:
finish_reasons_by_choice[idx].append(choice["finish_reason"])
# Both choices must independently produce the correct tool call.
for choice_idx in range(num_choices):
@@ -1984,141 +1988,11 @@ async def test_streaming_n_gt1_independent_tool_parsers():
f"Choice {choice_idx}: expected {{'city': 'Tokyo'}}, got {parsed_args}"
)
class TestCreateRemainingArgsDelta:
"""Tests for _create_remaining_args_delta helper function.
This helper is used when streaming tool calls to preserve id/type/name
fields in the finish chunk, which would otherwise be lost.
"""
def test_preserves_id_type_name(self):
"""Test that id, type, and name are preserved from original delta."""
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
reasons = finish_reasons_by_choice[choice_idx]
assert len(reasons) == 1, (
f"Choice {choice_idx}: expected exactly 1 finish_reason, got {reasons}"
)
original_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id="call_abc123",
type="function",
function=DeltaFunctionCall(
name="get_weather",
arguments='{"location": "Paris"}',
),
)
]
assert reasons[0] == "tool_calls", (
f"Choice {choice_idx}: expected finish_reason='tool_calls', "
f"got '{reasons[0]}'"
)
result = OpenAIServingChat._create_remaining_args_delta(
original_delta, '", "unit": "celsius"}', 0
)
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.index == 0
assert tc.id == "call_abc123"
assert tc.type == "function"
assert tc.function.name == "get_weather"
assert tc.function.arguments == '", "unit": "celsius"}'
def test_matches_by_index(self):
"""Test that the correct tool call is matched by index."""
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
)
original_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id="call_first",
type="function",
function=DeltaFunctionCall(name="func_a", arguments="{}"),
),
DeltaToolCall(
index=1,
id="call_second",
type="function",
function=DeltaFunctionCall(name="func_b", arguments="{}"),
),
]
)
result = OpenAIServingChat._create_remaining_args_delta(
original_delta, '{"extra": true}', 1
)
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.index == 1
assert tc.id == "call_second"
assert tc.function.name == "func_b"
def test_no_matching_tool_call(self):
"""Test graceful handling when no matching tool call is found."""
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
)
original_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id="call_zero",
type="function",
function=DeltaFunctionCall(name="func", arguments="{}"),
)
]
)
result = OpenAIServingChat._create_remaining_args_delta(
original_delta, '{"arg": 1}', 5
)
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.index == 5
assert tc.id is None
assert tc.type is None
assert tc.function.name is None
assert tc.function.arguments == '{"arg": 1}'
def test_function_is_none(self):
"""Test handling when original tool call has no function."""
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import DeltaMessage, DeltaToolCall
original_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id="call_nofunc",
type="function",
function=None,
)
]
)
result = OpenAIServingChat._create_remaining_args_delta(
original_delta, '{"data": "value"}', 0
)
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.index == 0
assert tc.id == "call_nofunc"
assert tc.type == "function"
assert tc.function.name is None
assert tc.function.arguments == '{"data": "value"}'
@@ -6,6 +6,7 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from pydantic import ValidationError
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
@@ -302,6 +303,36 @@ def test_json_schema_response_format_missing_schema():
)
@pytest.mark.parametrize("format_value", [None, {}])
def test_structural_tag_response_format_invalid(format_value):
"""Malformed structural tags should be rejected during request validation."""
with pytest.raises(
ValidationError,
match="Invalid response_format structural_tag",
):
CompletionRequest(
model=MODEL_NAME,
prompt="Test prompt",
max_tokens=10,
response_format={"type": "structural_tag", "format": format_value},
)
@pytest.mark.parametrize("structural_tag", ["not json", ""])
def test_structured_outputs_structural_tag_invalid(structural_tag):
"""Malformed direct structured_outputs structural tags should be rejected."""
with pytest.raises(
ValidationError,
match="Invalid structured_outputs structural_tag",
):
CompletionRequest(
model=MODEL_NAME,
prompt="Test prompt",
max_tokens=10,
structured_outputs={"structural_tag": structural_tag},
)
def test_negative_prompt_token_ids_nested():
"""Negative token IDs in prompt (nested list) should raise validation error."""
with pytest.raises(Exception, match="greater than or equal to 0"):
+83
View File
@@ -235,3 +235,86 @@ def test_parse_delta_reasoning_only_thinking_disabled(tokenizer, request_obj):
assert "Hello" in content
assert "assist" in content
assert len(tool_calls) == 0
def test_parse_delta_finished_no_flush_without_tool_call_delta(tokenizer, request_obj):
"""When finished=True but the final parse_delta produces no
tool-call delta, unstreamed args are not flushed."""
parser = make_parser(tokenizer, reasoning=False, tool=True)
results = stream_text(
parser, tokenizer, MODEL_OUTPUT, request_obj, prompt_token_ids=[]
)
_, _, tool_calls = collect_fields(results)
assert len(tool_calls) > 0
streamed = parser._tool_parser.streamed_args_for_tool[0]
assert len(streamed) > 5
parser._tool_parser.streamed_args_for_tool[0] = streamed[:-5]
# Prevent normal extraction from catching the gap — without a
# tool-call delta to merge into, the flush is skipped.
parser._tool_parser.extract_tool_calls_streaming = lambda *a, **kw: None
flush_result = parser.parse_delta("", [], request_obj, finished=True)
assert flush_result is None or flush_result.tool_calls is None
def test_parse_delta_finished_no_extra_args_when_fully_streamed(tokenizer, request_obj):
"""When all args have been streamed, finished=True must not
produce extra or duplicate arguments."""
parser = make_parser(tokenizer, reasoning=False, tool=True)
results = stream_text(
parser, tokenizer, MODEL_OUTPUT, request_obj, prompt_token_ids=[]
)
_, _, tool_calls = collect_fields(results)
assert len(tool_calls) > 0
assert tool_calls[0].function.name == "get_weather"
tool_args = "".join(
tc.function.arguments for tc in tool_calls if tc.function.arguments
)
assert json.loads(tool_args) == {"city": "Dallas"}
flush_result = parser.parse_delta("", [], request_obj, finished=True)
assert flush_result is None or flush_result.tool_calls is None
def test_parse_delta_finished_appends_remaining_args(tokenizer, request_obj):
"""When finished=True and the tool parser has unstreamed args,
parse_delta appends the remaining arguments to the tool-call delta."""
parser = make_parser(tokenizer, reasoning=False, tool=True)
token_ids = tokenizer.encode(MODEL_OUTPUT, add_special_tokens=False)
remainder = ',"unit":"celsius"}'
prompt_ids: list[int] | None = []
results: list[DeltaMessage | None] = []
for i, tid in enumerate(token_ids):
prev = results[-1] if results else None
prev_had_args = (
prev
and prev.tool_calls
and any(tc.function and tc.function.arguments for tc in prev.tool_calls)
)
if prev_had_args:
parser._tool_parser.get_remaining_unstreamed_args = lambda: remainder
result = parser.parse_delta(
tokenizer.decode([tid]),
[tid],
request_obj,
prompt_token_ids=prompt_ids,
finished=prev_had_args,
)
prompt_ids = None
results.append(result)
if prev_had_args:
break
_, _, tool_calls = collect_fields(results)
tool_args = "".join(
tc.function.arguments for tc in tool_calls if tc.function.arguments
)
assert tool_args.endswith(remainder)
+26 -2
View File
@@ -1382,7 +1382,20 @@ def test_adjust_request_non_mistral_tokenizer(
[
{"regex": r"\d+"},
{"choice": ["a", "b"]},
{"structural_tag": '{"key": "value"}'},
{
"structural_tag": json.dumps(
{
"structures": [
{
"begin": "<tool>",
"schema": {"type": "object"},
"end": "</tool>",
}
],
"triggers": ["<tool>"],
}
)
},
{"grammar": "start: 'hello'"},
],
ids=["regex", "choice", "structural_tag", "grammar"],
@@ -1404,7 +1417,18 @@ def test_adjust_request_unsupported_response_format(
) -> None:
request = _make_request(
response_format=StructuralTagResponseFormat(
type="structural_tag", format={"some": "config"}
type="structural_tag",
format={
"type": "triggered_tags",
"tags": [
{
"begin": "<tool>",
"content": {"type": "any_text"},
"end": "</tool>",
}
],
"triggers": ["<tool>"],
},
),
)
result = mistral_tool_parser.adjust_request(request)
@@ -1305,3 +1305,79 @@ def test_swa_alignment_skip(request_runner, async_scheduling: bool):
(1, 7),
),
)
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_stale_sliding_window_block_after_prepare_store_failure(
request_runner, async_scheduling: bool
):
"""Regression test: when prepare_store fails (returns None), offloading is
delayed. Meanwhile, sliding window blocks get freed and reallocated to the
same request. On retry, the stale block_id must be detected and skipped.
Without the fix, the stale block_id would either:
- Cause a KeyError in _remove_pending_job (duplicate in
_block_id_to_pending_jobs)
- Silently offload wrong data under a wrong key
"""
block_size = 4
# sliding_window = 8 -> window of 2 blocks
sliding_window = 8
# Use a tight GPU block budget so freed sliding window blocks are
# immediately reused by the same request's new allocations.
num_gpu_blocks = 4
kv_cache_groups = [
KVCacheGroupSpec(
["layer0"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=sliding_window,
),
),
]
runner = request_runner(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
async_scheduling=async_scheduling,
kv_cache_groups=kv_cache_groups,
)
# Request with 3 blocks of prompt. Window = 2 blocks, so block 0 is
# outside the window but won't be freed until the next allocate_slots.
runner.new_request(token_ids=[0] * block_size * 3)
# First step: prepare_store FAILS -> offloading delayed.
# next_stored_block_idx stays at 0, block_ids[0] still holds the
# original block_id for position 0.
runner.manager.prepare_store.side_effect = lambda keys, req_context: None
runner.run(decoded_tokens=[0])
runner.manager.prepare_store.assert_called()
# Second step: decode more tokens -> block 3 allocated.
# allocate_slots calls remove_skipped_blocks which frees block 0
# (it's now outside the sliding window). With num_gpu_blocks=4,
# the freed block is immediately reused for the new allocation.
# prepare_store still fails so offloading is still delayed.
runner.manager.prepare_store.side_effect = lambda keys, req_context: None
runner.run(decoded_tokens=[0] * block_size)
# Now prepare_store succeeds.
# Without the fix, the request would try to offload the stale block_id
# at position 0 (now reused at position 3), causing a duplicate in
# sliding_window_block_ids and eventually a KeyError.
runner.manager.prepare_store.side_effect = lambda keys, req_context: (
generate_store_output(keys)
)
# block_ids=[0, ?, 3, 1]: positions 0 and 1 are zeroed (stale blocks that
# were freed by the sliding window and reallocated). Only blocks at
# positions 2 and 3 (request offsets 2, 3) are stored.
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored=(2, 3),
expected_flushed=(2, 3) if not async_scheduling else (),
)
@@ -461,6 +461,35 @@ def test_store_sending_thread_only_skips_on_no_available_handle():
assert store.batch_put_from_multi_buffers.call_count == 2
def test_store_sending_thread_releases_pin_on_batch_is_exist_failure():
# `batch_is_exist` raising must still decrement `stored_requests` so the
# scheduler can drop `delay_free_blocks` and release the pinned GPU blocks.
store = MagicMock()
store.batch_is_exist.side_effect = RuntimeError("mooncake down")
thread = _make_store_sending_thread(store)
thread.add_stored_request("req-a")
with pytest.raises(RuntimeError):
thread._handle_request(_make_store_req("req-a", [b"a0", b"a1"]))
assert thread.stored_requests["req-a"] == 0
store.batch_put_from_multi_buffers.assert_not_called()
def test_store_sending_thread_releases_pin_on_batch_put_failure():
# `batch_put_from_multi_buffers` raising is logged (not re-raised), and the
# pin must still be released through the finally block.
store = MagicMock()
store.batch_is_exist.return_value = [0, 0]
store.batch_put_from_multi_buffers.side_effect = RuntimeError("rdma error")
thread = _make_store_sending_thread(store)
thread.add_stored_request("req-a")
thread._handle_request(_make_store_req("req-a", [b"a0", b"a1"]))
assert thread.stored_requests["req-a"] == 0
def test_store_recving_thread_reports_failed_block_ids():
store = MagicMock()
store.batch_get_into_multi_buffers.return_value = [256, -5, -7]
+10 -1
View File
@@ -338,7 +338,16 @@ def _xpu_mxfp8_quantize_impl(
shape = x.shape[:-1] + (x.shape[-1] // MXFP8_BLOCK_SIZE,)
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
torch.ops._C.per_token_group_fp8_quant(
x, x_q, x_s, MXFP8_BLOCK_SIZE, eps, fp8_min, fp8_max, True
x,
x_q,
x_s,
MXFP8_BLOCK_SIZE,
eps,
fp8_min,
fp8_max,
True,
False,
False, # dummy_is_scale_transposed, dummy_is_tma_aligned
)
x_s = x_s.to(torch.float8_e8m0fnu)
return x_q, x_s
@@ -517,187 +517,190 @@ class KVCacheStoreSendingThread(KVTransferThread):
if req_id not in self.stored_requests:
self.request_queue.task_done()
return
if token_len == 0:
self.dec_stored_request(req_id)
self.request_queue.task_done()
return
if self._should_skip_request(req_id):
logger.debug(
"Skipping Mooncake store for request %s while CPU/disk offloading "
"is under pressure",
req_id,
)
self.dec_stored_request(req_id)
self.request_queue.task_done()
return
# Within each lcm region only per-spec relevant chunks are loaded
# (e.g., SWA or linear attn), so mask out irrelevant chunks
store_masks = self.coord.store_mask(token_len)
starts: list[int] = []
ends: list[int] = []
keys: list[str] = []
block_hashes: list[BlockHash] = []
group_indices: list[int] = []
for g_idx, db in enumerate(self.token_databases):
mask = store_masks[g_idx]
for chunk_idx, (start, end, key) in enumerate(
db.process_tokens(token_len, req_meta.block_hashes)
):
if chunk_idx >= len(mask) or not mask[chunk_idx]:
continue
starts.append(start)
ends.append(end)
keys.append(key.to_string())
block_hashes.append(req_meta.block_hashes[chunk_idx])
group_indices.append(g_idx)
# Apply put_step striding for TP
sl = slice(self.tp_rank % self.put_step, None, self.put_step)
starts = starts[sl]
ends = ends[sl]
keys = keys[sl]
block_hashes = block_hashes[sl]
group_indices = group_indices[sl]
if not keys:
self.dec_stored_request(req_id)
return
# Check which blocks already exist (dedup)
save_exists_start = time.perf_counter()
# Decrement the in-flight counter and signal task_done() in `finally`
# so the scheduler can release the GPU blocks it pinned for this
# request (via `delay_free_blocks`) even when the store path raises.
try:
exists_states = self.store.batch_is_exist(keys)
except Exception:
if token_len == 0:
return
if self._should_skip_request(req_id):
logger.debug(
"Skipping Mooncake store for request %s while CPU/disk "
"offloading is under pressure",
req_id,
)
return
# Within each lcm region only per-spec relevant chunks are loaded
# (e.g., SWA or linear attn), so mask out irrelevant chunks
store_masks = self.coord.store_mask(token_len)
starts: list[int] = []
ends: list[int] = []
keys: list[str] = []
block_hashes: list[BlockHash] = []
group_indices: list[int] = []
for g_idx, db in enumerate(self.token_databases):
mask = store_masks[g_idx]
for chunk_idx, (start, end, key) in enumerate(
db.process_tokens(token_len, req_meta.block_hashes)
):
if chunk_idx >= len(mask) or not mask[chunk_idx]:
continue
starts.append(start)
ends.append(end)
keys.append(key.to_string())
block_hashes.append(req_meta.block_hashes[chunk_idx])
group_indices.append(g_idx)
# Apply put_step striding for TP
sl = slice(self.tp_rank % self.put_step, None, self.put_step)
starts = starts[sl]
ends = ends[sl]
keys = keys[sl]
block_hashes = block_hashes[sl]
group_indices = group_indices[sl]
if not keys:
return
# Check which blocks already exist (dedup)
save_exists_start = time.perf_counter()
try:
exists_states = self.store.batch_is_exist(keys)
except Exception:
self._record_operation(
"save_exists",
save_exists_start,
len(keys),
status="error",
num_failed_keys=len(keys),
)
raise
self._record_operation(
"save_exists",
save_exists_start,
len(keys),
status="error",
num_failed_keys=len(keys),
)
raise
self._record_operation(
"save_exists",
save_exists_start,
len(keys),
)
missing_indices = [i for i, exists in enumerate(exists_states) if exists != 1]
missing_indices = [
i for i, exists in enumerate(exists_states) if exists != 1
]
if not missing_indices:
self.dec_stored_request(req_id)
return
if not missing_indices:
return
starts = [starts[i] for i in missing_indices]
ends = [ends[i] for i in missing_indices]
keys = [keys[i] for i in missing_indices]
block_hashes = [block_hashes[i] for i in missing_indices]
group_indices = [group_indices[i] for i in missing_indices]
starts = [starts[i] for i in missing_indices]
ends = [ends[i] for i in missing_indices]
keys = [keys[i] for i in missing_indices]
block_hashes = [block_hashes[i] for i in missing_indices]
group_indices = [group_indices[i] for i in missing_indices]
logger.debug(
"Storing KV cache for %d blocks (groups=%s) for request %s",
len(keys),
set(group_indices),
req_id,
)
addrs: list[list[int]] = []
sizes: list[list[int]] = []
stored_events: list[BlockStored] = []
# parent_block_hash chains live within a group, not across.
prev_key_per_group: dict[int, Any] = {}
new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes]
for idx, (s, e, g_idx) in enumerate(
zip(starts, ends, group_indices, strict=True)
):
db = self.token_databases[g_idx]
addr, size, _ = db.prepare_value(s, e, block_ids_per_group[g_idx])
addrs.append(addr)
sizes.append(size)
if self.enable_kv_event:
token_ids = (
req_meta.token_ids[s:e] if req_meta.token_ids is not None else None
)
stored_event = BlockStored(
block_hashes=[new_block_hashes[idx]],
parent_block_hash=prev_key_per_group.get(g_idx),
token_ids=token_ids,
block_size=req_meta.original_block_size,
lora_id=None,
medium="cpu",
lora_name=None,
)
stored_events.append(stored_event)
prev_key_per_group[g_idx] = new_block_hashes[idx]
if current_event is not None:
current_event.synchronize()
batch_bytes = _sum_batch_bytes(sizes)
put_start = time.perf_counter()
try:
res = self.store.batch_put_from_multi_buffers(
keys,
addrs,
sizes,
self.replicate_config,
)
failed = [i for i, v in enumerate(res) if v < 0]
self._record_operation(
"save_put",
put_start,
logger.debug(
"Storing KV cache for %d blocks (groups=%s) for request %s",
len(keys),
num_bytes=batch_bytes,
status="partial_failure" if failed else "ok",
num_failed_keys=len(failed),
set(group_indices),
req_id,
)
if failed:
failed_codes = set(res[i] for i in failed)
logger.warning(
"batch_put failed: %d/%d keys failed "
"(codes=%s, batch_bytes=%d, num_keys=%d), "
"first_key=%s",
len(failed),
len(keys),
failed_codes,
batch_bytes,
len(keys),
keys[0] if keys else "N/A",
)
if (
MOONCAKE_NO_AVAILABLE_HANDLE in failed_codes
and not self._mark_request_skipped_for_pressure(req_id)
):
logger.warning(
"Detected Mooncake CPU/disk offloading pressure "
"(NO_AVAILABLE_HANDLE); skipping future store "
"batches for request %s until a later store "
"batch succeeds",
req_id,
addrs: list[list[int]] = []
sizes: list[list[int]] = []
stored_events: list[BlockStored] = []
# parent_block_hash chains live within a group, not across.
prev_key_per_group: dict[int, Any] = {}
new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes]
for idx, (s, e, g_idx) in enumerate(
zip(starts, ends, group_indices, strict=True)
):
db = self.token_databases[g_idx]
addr, size, _ = db.prepare_value(s, e, block_ids_per_group[g_idx])
addrs.append(addr)
sizes.append(size)
if self.enable_kv_event:
token_ids = (
req_meta.token_ids[s:e]
if req_meta.token_ids is not None
else None
)
elif self._clear_store_pressure():
logger.info(
"Mooncake CPU/disk offloading pressure cleared after a "
"successful store batch"
stored_event = BlockStored(
block_hashes=[new_block_hashes[idx]],
parent_block_hash=prev_key_per_group.get(g_idx),
token_ids=token_ids,
block_size=req_meta.original_block_size,
lora_id=None,
medium="cpu",
lora_name=None,
)
stored_events.append(stored_event)
prev_key_per_group[g_idx] = new_block_hashes[idx]
if current_event is not None:
current_event.synchronize()
batch_bytes = _sum_batch_bytes(sizes)
put_start = time.perf_counter()
try:
res = self.store.batch_put_from_multi_buffers(
keys,
addrs,
sizes,
self.replicate_config,
)
except Exception as e:
self._record_operation(
"save_put",
put_start,
len(keys),
num_bytes=batch_bytes,
status="error",
num_failed_keys=len(keys),
)
logger.error("Failed to put key %s, error: %s", keys, e)
failed = [i for i, v in enumerate(res) if v < 0]
self._record_operation(
"save_put",
put_start,
len(keys),
num_bytes=batch_bytes,
status="partial_failure" if failed else "ok",
num_failed_keys=len(failed),
)
if failed:
failed_codes = set(res[i] for i in failed)
logger.warning(
"batch_put failed: %d/%d keys failed "
"(codes=%s, batch_bytes=%d, num_keys=%d), "
"first_key=%s",
len(failed),
len(keys),
failed_codes,
batch_bytes,
len(keys),
keys[0] if keys else "N/A",
)
if (
MOONCAKE_NO_AVAILABLE_HANDLE in failed_codes
and not self._mark_request_skipped_for_pressure(req_id)
):
logger.warning(
"Detected Mooncake CPU/disk offloading pressure "
"(NO_AVAILABLE_HANDLE); skipping future store "
"batches for request %s until a later store "
"batch succeeds",
req_id,
)
elif self._clear_store_pressure():
logger.info(
"Mooncake CPU/disk offloading pressure cleared after a "
"successful store batch"
)
except Exception as e:
self._record_operation(
"save_put",
put_start,
len(keys),
num_bytes=batch_bytes,
status="error",
num_failed_keys=len(keys),
)
logger.error("Failed to put key %s, error: %s", keys, e)
if self.enable_kv_event and stored_events:
self.update_kv_event(stored_events)
self.dec_stored_request(req_id)
self.request_queue.task_done()
if self.enable_kv_event and stored_events:
self.update_kv_event(stored_events)
finally:
self.dec_stored_request(req_id)
self.request_queue.task_done()
class KVCacheStoreRecvingThread(KVTransferThread):
@@ -286,6 +286,8 @@ class OffloadingConnectorScheduler:
self._req_status: dict[ReqId, RequestOffloadState] = {}
self._current_batch_load_jobs: dict[int, TransferJob] = {}
self._current_batch_jobs_to_flush: set[int] = set()
# GPU block IDs allocated in the current engine step
self._current_batch_allocated_block_ids: set[int] = set()
# if GPU prefix caching is enabled,
# track loaded blocks to avoid redundant loads
self._blocks_being_loaded: set[OffloadKey] | None = (
@@ -589,6 +591,10 @@ class OffloadingConnectorScheduler:
req_status.group_states,
blocks.blocks,
):
self._current_batch_allocated_block_ids.update(
block.block_id for block in group_blocks if block.block_id != 0
)
gpu_block_size = group_config.gpu_block_size
offloaded_block_size = group_config.offloaded_block_size
offload_keys = group_state.offload_keys
@@ -638,17 +644,6 @@ class OffloadingConnectorScheduler:
if req_status.offloading_context.policy == OffloadPolicy.BLOCK_LEVEL:
group_state.next_stored_block_idx = num_blocks
# Fence dst blocks against finished-request pending stores.
if (
self._block_id_to_pending_jobs
and not self._block_id_to_pending_jobs.keys().isdisjoint(dst_block_ids)
):
self._current_batch_jobs_to_flush.update(
jid
for bid in dst_block_ids
for jid in self._block_id_to_pending_jobs.get(bid, ())
)
src_spec = self.manager.prepare_load(keys_to_load, req_status.req_context)
dst_spec = GPULoadStoreSpec(
dst_block_ids, group_sizes=group_sizes, block_indices=block_indices
@@ -672,37 +667,68 @@ class OffloadingConnectorScheduler:
if self._blocks_being_loaded is not None:
self._blocks_being_loaded.update(keys_to_load)
def _build_store_jobs(
self,
scheduler_output: SchedulerOutput,
) -> dict[int, TransferJob]:
block_size_factor = self.config.block_size_factor
store_jobs: dict[int, TransferJob] = {}
# iterate over both new and cached requests
def _update_req_states(self, scheduler_output: SchedulerOutput) -> None:
"""
Update request states from the Scheduler's output.
"""
# new_block_ids_end[req_id][i] = end of pre-existing block_ids for
# the i-th sliding window group (before this step's extend).
# Used to detect sliding window blocks that got re-allocated.
new_block_ids_end: dict[str, tuple[int, ...]] = {}
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
req_status = self._req_status[req_id]
req_status.update_offload_keys()
req = req_status.req
if preempted:
for group_state in req_status.group_states:
group_state.block_ids.clear()
if new_block_id_groups:
if self._sliding_window_groups:
new_block_ids_end[req_id] = tuple(
len(req_status.group_states[grp_idx].block_ids)
for grp_idx in self._sliding_window_groups
)
req_status.update_block_id_groups(new_block_id_groups)
# Fence new blocks against in-flight stores.
if self._block_id_to_pending_jobs:
new_blocks_flat = [
bid for new_blocks in new_block_id_groups for bid in new_blocks
]
if not self._block_id_to_pending_jobs.keys().isdisjoint(
new_blocks_flat
):
self._current_batch_jobs_to_flush.update(
jid
for bid in new_blocks_flat
for jid in self._block_id_to_pending_jobs.get(bid, ())
)
for new_blocks in new_block_id_groups:
for bid in new_blocks:
if bid != 0:
self._current_batch_allocated_block_ids.add(bid)
# Zero out stale block_ids in sliding window groups' pending-store
# positions. Only sliding window groups can have stale entries (blocks
# freed by remove_skipped_blocks then reallocated). Only positions in
# [next_stored_block_idx * bsf, end) need checking where end is the
# pre-extend length: earlier positions were already offloaded, later
# ones are fresh allocations from this step.
if self._sliding_window_groups and self._current_batch_allocated_block_ids:
block_size_factor = self.config.block_size_factor
for req_id, req_status in self._req_status.items():
ends = new_block_ids_end.get(req_id)
for i, grp_idx in enumerate(self._sliding_window_groups):
group_state = req_status.group_states[grp_idx]
start = group_state.next_stored_block_idx * block_size_factor
end = ends[i] if ends is not None else len(group_state.block_ids)
for j in range(start, end):
if (
group_state.block_ids[j]
in self._current_batch_allocated_block_ids
):
group_state.block_ids[j] = 0
def _build_store_jobs(
self,
scheduler_output: SchedulerOutput,
) -> dict[int, TransferJob]:
block_size_factor = self.config.block_size_factor
store_jobs: dict[int, TransferJob] = {}
for req_id in scheduler_output.num_scheduled_tokens:
req_status = self._req_status.get(req_id)
if req_status is None:
continue
req = req_status.req
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_tokens_after_batch = req.num_computed_tokens + num_scheduled_tokens
@@ -735,11 +761,8 @@ class OffloadingConnectorScheduler:
# For each block to offload, take the last corresponding GPU block.
# e.g. if block size factor is 3 and GPU block IDs are
# 1 5 6 7 2 4 9 3 8 then we'll take blocks 6 4 8.
# We will use these GPU blocks to determine if the block needs
# offloading, or (if the GPU block ID is 0) this block should
# be skipped due to sliding window attention / SSM.
# We know that if a block is skipped, then all the previous blocks
# are skipped as well. This is why we take the last of each block.
# A block_id of 0 means either a sliding window / SSM skip
# or a stale entry that was zeroed out — skip it either way.
offload_block_ids = group_state.block_ids[
start_block_idx * block_size_factor
+ block_size_factor
@@ -814,10 +837,8 @@ class OffloadingConnectorScheduler:
for i in range(block_size_factor):
block_id = block_ids[gpu_block_idx + i]
if block_id == 0:
# skipped blocks cannot appear after non-skipped blocks
assert start_gpu_block_idx is None
continue
elif start_gpu_block_idx is None:
if start_gpu_block_idx is None:
start_gpu_block_idx = gpu_block_idx + i
src_block_ids.append(block_id)
num_group_blocks += 1
@@ -875,6 +896,9 @@ class OffloadingConnectorScheduler:
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
self._update_req_states(scheduler_output)
# Flush jobs for preempted requests.
for req_id in scheduler_output.preempted_req_ids or ():
req_status = self._req_status.get(req_id)
if req_status is None or not req_status.transfer_jobs:
@@ -883,6 +907,20 @@ class OffloadingConnectorScheduler:
assert self._jobs[any_jid].is_store
self._current_batch_jobs_to_flush.update(req_status.transfer_jobs)
# Flush jobs that contain re-allocated blocks.
if (
self._block_id_to_pending_jobs
and not self._block_id_to_pending_jobs.keys().isdisjoint(
self._current_batch_allocated_block_ids
)
):
self._current_batch_jobs_to_flush.update(
jid
for bid in self._current_batch_allocated_block_ids
if bid in self._block_id_to_pending_jobs
for jid in self._block_id_to_pending_jobs[bid]
)
# If all tracked requests are finished, flush all pending jobs
# (both store and load) - there might not be a future scheduler
# step to trigger their completion.
@@ -898,6 +936,7 @@ class OffloadingConnectorScheduler:
)
self._current_batch_load_jobs = {}
self._current_batch_jobs_to_flush = set()
self._current_batch_allocated_block_ids = set()
return meta
def update_connector_output(self, connector_output: KVConnectorOutput):
@@ -1013,6 +1052,7 @@ class OffloadingConnectorScheduler:
# reset_cache cannot be called in the middle of a schedule step
assert not self._current_batch_load_jobs
assert not self._current_batch_jobs_to_flush
assert not self._current_batch_allocated_block_ids
# Flush all in-flight jobs
self._current_batch_jobs_to_flush.update(self._jobs.keys())
@@ -30,6 +30,8 @@ from vllm.entrypoints.openai.engine.protocol import (
StructuralTagResponseFormat,
ToolCall,
UsageInfo,
validate_structural_tag_response_format,
validate_structured_outputs_structural_tag,
)
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
@@ -671,6 +673,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
parameter="response_format",
)
if rf_type == "structural_tag":
validate_structural_tag_response_format(response_format)
return data
@model_validator(mode="before")
@@ -754,6 +759,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
"You can only either use constraints for structured outputs "
"or tools, not both.",
)
validate_structured_outputs_structural_tag(structured_outputs_kwargs)
return data
@model_validator(mode="before")
@@ -979,6 +985,16 @@ class BatchChatCompletionRequest(OpenAIBaseModel):
"Batch chat completions do not support beam search. "
"Please set `use_beam_search` to False."
)
response_format = data.get("response_format")
rf_type = (
response_format.get("type")
if isinstance(response_format, dict)
else getattr(response_format, "type", None)
)
if rf_type == "structural_tag":
validate_structural_tag_response_format(response_format)
if (structured_outputs := data.get("structured_outputs")) is not None:
validate_structured_outputs_structural_tag(structured_outputs)
n = data.get("n", 1)
if n is not None and n != 1:
raise ValueError(
@@ -3,7 +3,6 @@
import asyncio
import io
import json
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
@@ -40,9 +39,7 @@ from vllm.entrypoints.openai.chat_completion.stream_harmony import (
extract_harmony_streaming_delta,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ErrorResponse,
FunctionCall,
PromptTokenUsageInfo,
@@ -65,7 +62,7 @@ from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs import EngineInput
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.outputs import RequestOutput
from vllm.parser import ParserManager
from vllm.parser.abstract_parser import Parser
from vllm.reasoning import ReasoningParser
@@ -715,6 +712,7 @@ class OpenAIServingChat(OpenAIServing):
delta_token_ids=as_list(output.token_ids),
request=request,
prompt_token_ids=res.prompt_token_ids,
finished=output.finish_reason is not None,
)
if delta_message and delta_message.tool_calls:
tools_streamed[i] = True
@@ -805,81 +803,13 @@ class OpenAIServingChat(OpenAIServing):
# finish_reason='error' indicates a retryable error
self._raise_if_error(output.finish_reason, request_id)
# check to make sure we haven't "forgotten" to stream
# any tokens that were generated but previously
# matched by partial json parsing
# only happens if we are NOT using structured outputs
index = 0
auto_tools_called = False
if tool_parser:
auto_tools_called = len(tool_parser.prev_tool_call_arr) > 0
index = (
len(tool_parser.prev_tool_call_arr) - 1
if auto_tools_called
else 0
)
should_check = (
self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output
)
)
# only check if there are any tool calls
# detected by partial parsing
if should_check and tool_parser and auto_tools_called:
latest_delta_len = 0
if (
isinstance(
delta_message.tool_calls[0].function,
DeltaFunctionCall,
)
) and isinstance(
delta_message.tool_calls[0].function.arguments, str
):
latest_delta_len = len(
delta_message.tool_calls[0].function.arguments
)
# get the expected call based on partial JSON
# parsing which "autocompletes" the JSON.
# Tool parsers (e.g. Qwen3Coder) store
# arguments as a JSON string in
# prev_tool_call_arr. Calling json.dumps()
# on an already-serialized string would
# double-serialize it (e.g. '{"k":1}' becomes
# '"{\\"k\\":1}"'), which then causes the
# replace() below to fail and append the
# entire double-serialized string as a
# spurious final delta.
args = tool_parser.prev_tool_call_arr[index].get(
"arguments", {}
)
if isinstance(args, str):
expected_call = args
else:
expected_call = json.dumps(args, ensure_ascii=False)
# get what we've streamed so far for arguments
# for the current tool
actual_call = tool_parser.streamed_args_for_tool[index]
if latest_delta_len > 0:
actual_call = actual_call[:-latest_delta_len]
# check to see if there's anything left to stream
remaining_call = expected_call.replace(actual_call, "", 1)
# set that as a delta message
delta_message = self._create_remaining_args_delta(
delta_message, remaining_call, index
)
# Send the finish response for each request.n only once
# In OpenAI's API, when a tool is called, the
# finish_reason is:
# "tool_calls" for "auto" or "required" tool calls,
# and "stop" for named tool calls.
if (
auto_tools_called
or (tools_streamed[i] and not tool_choice_function_name)
or (self.use_harmony and harmony_tools_streamed[i])
if (tools_streamed[i] and not tool_choice_function_name) or (
self.use_harmony and harmony_tools_streamed[i]
):
finish_reason_ = "tool_calls"
else:
@@ -1535,56 +1465,3 @@ class OpenAIServingChat(OpenAIServing):
and self.enable_auto_tools
and request.tool_choice in ["auto", None]
)
def _should_check_for_unstreamed_tool_arg_tokens(
self,
delta_message: DeltaMessage | None,
output: CompletionOutput,
) -> bool:
"""
Check to see if we should check for unstreamed tool arguments tokens.
This is only applicable when auto tool parsing is enabled, the delta
is a tool call with arguments.
"""
return bool(
# if there is a delta message that includes tool calls which
# include a function that has arguments
output.finish_reason is not None
and self.enable_auto_tools
and self.tool_parser
and delta_message
and delta_message.tool_calls
and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None
)
@staticmethod
def _create_remaining_args_delta(
delta_message: DeltaMessage,
remaining_call: str,
index: int,
) -> DeltaMessage:
"""
Create a delta message for remaining tool arguments, preserving
id/type/name from the original delta.
"""
original_tc = next(
(tc for tc in delta_message.tool_calls if tc.index == index),
None,
)
original_fn = original_tc.function if original_tc else None
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=index,
id=original_tc.id if original_tc else None,
type=original_tc.type if original_tc else None,
function=DeltaFunctionCall(
name=original_fn.name if original_fn else None,
arguments=remaining_call,
),
)
]
)
@@ -18,6 +18,8 @@ from vllm.entrypoints.openai.engine.protocol import (
StreamOptions,
StructuralTagResponseFormat,
UsageInfo,
validate_structural_tag_response_format,
validate_structured_outputs_structural_tag,
)
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
@@ -370,6 +372,9 @@ class CompletionRequest(OpenAIBaseModel):
parameter="response_format",
)
if rf_type == "structural_tag":
validate_structural_tag_response_format(response_format)
return data
@model_validator(mode="before")
@@ -397,6 +402,7 @@ class CompletionRequest(OpenAIBaseModel):
"outputs ('json', 'regex' or 'choice').",
parameter="structured_outputs",
)
validate_structured_outputs_structural_tag(structured_outputs_kwargs)
return data
@model_validator(mode="before")
@@ -17,6 +17,7 @@ from pydantic import (
)
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -158,6 +159,80 @@ AnyResponseFormat: TypeAlias = (
)
def validate_structural_tag_response_format(
response_format: AnyStructuralTagResponseFormat | dict[str, Any],
) -> None:
"""Validate structural tags before they are sent to the engine.
Engine-side validation reports malformed structural tags as generation
failures. OpenAI request parsing should classify them as bad requests.
"""
import json
from pydantic import TypeAdapter, ValidationError
if isinstance(response_format, dict):
try:
response_format = TypeAdapter(
AnyStructuralTagResponseFormat
).validate_python(response_format)
except ValidationError as exc:
raise VLLMValidationError(
"Invalid response_format structural_tag specification.",
parameter="response_format",
) from exc
try:
payload = json.dumps(response_format.model_dump(by_alias=True))
validate_structural_tag_payload(payload, parameter="response_format")
except (TypeError, ValueError) as exc:
raise VLLMValidationError(
"Invalid response_format structural_tag specification.",
parameter="response_format",
) from exc
def validate_structural_tag_payload(payload: Any, *, parameter: str) -> None:
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar
if isinstance(payload, str) and not payload:
raise VLLMValidationError(
f"Invalid {parameter} structural_tag specification.",
parameter=parameter,
)
try:
validate_xgrammar_grammar(
SamplingParams(
structured_outputs=StructuredOutputsParams(structural_tag=payload)
)
)
except (TypeError, ValueError) as exc:
raise VLLMValidationError(
f"Invalid {parameter} structural_tag specification.",
parameter=parameter,
) from exc
def validate_structured_outputs_structural_tag(
structured_outputs: Any,
) -> None:
from vllm.sampling_params import StructuredOutputsParams
if isinstance(structured_outputs, StructuredOutputsParams):
structural_tag = structured_outputs.structural_tag
elif isinstance(structured_outputs, dict):
structural_tag = structured_outputs.get("structural_tag")
else:
return
if structural_tag is not None:
validate_structural_tag_payload(
structural_tag,
parameter="structured_outputs",
)
class StreamOptions(OpenAIBaseModel):
include_usage: bool | None = False
continuous_usage_stats: bool | None = False
@@ -575,7 +575,7 @@ def per_token_group_quant_fp8(
# prefer CUDA/XPU kernel if available
# TODO(bnell): this causes some fp8 moe test to fail.
if current_platform.is_cuda() and x.is_contiguous():
if (current_platform.is_cuda() or current_platform.is_xpu()) and x.is_contiguous():
torch.ops._C.per_token_group_fp8_quant(
x,
x_q,
@@ -590,12 +590,6 @@ def per_token_group_quant_fp8(
)
return x_q, x_s
if current_platform.is_xpu() and x.is_contiguous():
torch.ops._C.per_token_group_fp8_quant(
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0
)
return x_q, x_s
# TRITON FALLBACK
M = x.numel() // group_size
N = group_size
+71 -4
View File
@@ -173,9 +173,64 @@ MiniCPMOAudioInputs: TypeAlias = (
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_features = hf_inputs.get("audio_features")
audio_feature_lens = hf_inputs.get("audio_feature_lens")
# For multi-chunk audio (>30s), audio_features has one item per chunk
# (total_chunks) while audio_feature_lens has one item per audio (N).
# Use flat to group audio_features by audio so both fields
# share the same batch size (N).
audio_features_cfg = MultiModalFieldConfig.batched("audio")
if audio_features is not None and audio_feature_lens is not None:
num_features = (
len(audio_features)
if isinstance(audio_features, (list, tuple))
else audio_features.shape[0]
)
num_audios = (
len(audio_feature_lens)
if isinstance(audio_feature_lens, (list, tuple))
else audio_feature_lens.shape[0]
)
if num_features > num_audios:
# Compute the number of chunks belonging to each audio
chunks_per_audio: list[int] = []
for lens in audio_feature_lens:
if isinstance(lens, torch.Tensor):
chunks_per_audio.append(lens.numel())
else:
chunks_per_audio.append(1)
# When audio_feature_lens is padded (e.g. from batched HF
# processor output), numel() over-counts. Fall back to
# counting non-zero entries so the sizes sum to num_features.
if sum(chunks_per_audio) != num_features:
chunks_per_audio = []
for lens in audio_feature_lens:
if isinstance(lens, torch.Tensor):
n = int((lens != 0).sum())
chunks_per_audio.append(max(n, 1))
else:
chunks_per_audio.append(1)
# Use flat (not flat_from_sizes) because audio_features
# is list[Tensor] with variable-length chunks (post-unpad).
slice_idxs = [0]
for n in chunks_per_audio:
slice_idxs.append(slice_idxs[-1] + n)
audio_features_cfg = MultiModalFieldConfig.flat(
"audio",
[
slice(slice_idxs[i], slice_idxs[i + 1])
for i in range(len(chunks_per_audio))
],
)
return dict(
**_minicpmv_field_config(hf_inputs),
audio_features=MultiModalFieldConfig.batched("audio"),
audio_features=audio_features_cfg,
audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"),
)
@@ -364,11 +419,23 @@ class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessing
# Avoid padding since we need the output for each audio to be
# independent of other audios for the cache to work correctly
# Flatten audio_feature_lens (list of tensors of any
# dimensionality, one per audio, each containing per-chunk
# lengths) into a flat list of integer lengths so there is
# one length per chunk, matching the first dimension of
# audio_features. Using flatten() handles 0-D, 1-D, and
# higher-dimensional tensors uniformly.
flat_feature_lens: list[int] = []
for lens in audio_inputs["audio_feature_lens"]:
if isinstance(lens, torch.Tensor):
flat_feature_lens.extend(lens.flatten().tolist())
else:
flat_feature_lens.append(int(lens))
unpadded_audio_features = [
feat[:, :feature_len]
for feat, feature_len in zip(
feat[:, :length]
for feat, length in zip(
audio_inputs["audio_features"],
audio_inputs["audio_feature_lens"],
flat_feature_lens,
)
]
audio_inputs["audio_features"] = unpadded_audio_features
+21
View File
@@ -320,6 +320,7 @@ class Parser:
delta_token_ids: list[int],
request: ChatCompletionRequest | ResponsesRequest,
prompt_token_ids: list[int] | None = None,
finished: bool = False,
) -> DeltaMessage | None:
"""Parse a single streaming delta, orchestrating reasoning then
tool call extraction via internal stream state.
@@ -656,12 +657,28 @@ class DelegatingParser(Parser):
return False
return state.reasoning_ended
def _append_unstreamed_tool_args(
self,
delta_message: DeltaMessage | None,
) -> None:
"""Append parsed-but-unstreamed tool-call arguments to *delta_message*."""
if (
self._tool_parser is not None
and delta_message
and delta_message.tool_calls
and (last_tc := delta_message.tool_calls[-1]).function
):
last_tc.function.arguments = (
last_tc.function.arguments or ""
) + self._tool_parser.get_remaining_unstreamed_args()
def parse_delta(
self,
delta_text: str,
delta_token_ids: list[int],
request: ChatCompletionRequest | ResponsesRequest,
prompt_token_ids: list[int] | None = None,
finished: bool = False,
) -> DeltaMessage | None:
state = self._stream_state
@@ -745,6 +762,10 @@ class DelegatingParser(Parser):
state.previous_text = current_text
state.previous_token_ids = current_token_ids
if finished:
self._append_unstreamed_tool_args(delta_message)
return delta_message
+19
View File
@@ -79,6 +79,25 @@ class ToolParser:
else:
self.tools = []
def get_remaining_unstreamed_args(self) -> str:
"""Return tool call arguments parsed but not yet streamed."""
if not self.prev_tool_call_arr:
return ""
index = len(self.prev_tool_call_arr) - 1
args = self.prev_tool_call_arr[index].get("arguments", {})
if isinstance(args, str):
expected = args
else:
expected = json.dumps(args, ensure_ascii=False)
actual = (
self.streamed_args_for_tool[index]
if index < len(self.streamed_args_for_tool)
else ""
)
if expected.startswith(actual):
return expected[len(actual) :]
return ""
@cached_property
def vocab(self) -> dict[str, int]:
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab