mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Merge branch 'main' into wosouk/dsv4-attn-cleanup-2
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Executable
+39
@@ -0,0 +1,39 @@
|
||||
#!/bin/bash
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
REQUIREMENTS_FILE="${KV_CONNECTORS_REQUIREMENTS:-/vllm-workspace/requirements/kv_connectors.txt}"
|
||||
|
||||
uv pip install --system -r "${REQUIREMENTS_FILE}"
|
||||
|
||||
NIXL_METADATA=$(python3 - <<'PY'
|
||||
import importlib.metadata as metadata
|
||||
|
||||
import torch
|
||||
|
||||
cuda_version = torch.version.cuda
|
||||
if cuda_version is None:
|
||||
raise SystemExit("torch.version.cuda is not set")
|
||||
|
||||
print(cuda_version.split(".", 1)[0], metadata.version("nixl"))
|
||||
PY
|
||||
)
|
||||
read -r CUDA_MAJOR NIXL_VERSION <<<"${NIXL_METADATA}"
|
||||
|
||||
# nixl>=1.1.0 can install multiple CUDA wheel variants. Keep only the variant
|
||||
# matching this CI image so nixl_ep_cpp links against the available libcudart.
|
||||
uv pip uninstall --system nixl-cu12 nixl-cu13 2>/dev/null || true
|
||||
uv pip install --system --no-deps "nixl-cu${CUDA_MAJOR}==${NIXL_VERSION}"
|
||||
|
||||
python3 - <<'PY'
|
||||
import importlib.metadata as metadata
|
||||
|
||||
for package_name in ("nixl", "nixl-cu12", "nixl-cu13"):
|
||||
try:
|
||||
version = metadata.version(package_name)
|
||||
except metadata.PackageNotFoundError:
|
||||
version = "not installed"
|
||||
print(f"{package_name}: {version}")
|
||||
PY
|
||||
@@ -11,7 +11,7 @@ steps:
|
||||
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
|
||||
- tests/v1/kv_connector/nixl_integration/
|
||||
commands:
|
||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
|
||||
- bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
|
||||
- label: Distributed FlashInfer NixlConnector PD accuracy (4 GPUs)
|
||||
key: distributed-flashinfer-nixlconnector-pd-accuracy-4-gpus
|
||||
@@ -22,7 +22,7 @@ steps:
|
||||
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
|
||||
- tests/v1/kv_connector/nixl_integration/
|
||||
commands:
|
||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
|
||||
- FLASHINFER=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
|
||||
|
||||
- label: DP EP Distributed NixlConnector PD accuracy tests (4 GPUs)
|
||||
@@ -34,7 +34,7 @@ steps:
|
||||
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
|
||||
- tests/v1/kv_connector/nixl_integration/
|
||||
commands:
|
||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
|
||||
- DP_EP=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
|
||||
|
||||
- label: CrossLayer KV layout Distributed NixlConnector PD accuracy tests (4 GPUs)
|
||||
@@ -46,7 +46,7 @@ steps:
|
||||
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
|
||||
- tests/v1/kv_connector/nixl_integration/
|
||||
commands:
|
||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
|
||||
- CROSS_LAYERS_BLOCKS=True bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
|
||||
|
||||
- label: Hybrid SSM NixlConnector PD accuracy tests (4 GPUs)
|
||||
@@ -58,7 +58,7 @@ steps:
|
||||
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
|
||||
- tests/v1/kv_connector/nixl_integration/
|
||||
commands:
|
||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
|
||||
- HYBRID_SSM=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
|
||||
|
||||
- label: MultiConnector (Nixl+Offloading) PD accuracy (2 GPUs)
|
||||
@@ -73,7 +73,7 @@ steps:
|
||||
- vllm/distributed/kv_transfer/kv_connector/v1/offloading/
|
||||
- tests/v1/kv_connector/nixl_integration/
|
||||
commands:
|
||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
|
||||
- bash v1/kv_connector/nixl_integration/run_multi_connector_accuracy_test.sh
|
||||
|
||||
- label: NixlConnector PD + Spec Decode acceptance (2 GPUs)
|
||||
@@ -87,7 +87,7 @@ steps:
|
||||
- vllm/v1/worker/kv_connector_model_runner_mixin.py
|
||||
- tests/v1/kv_connector/nixl_integration/
|
||||
commands:
|
||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
|
||||
- bash v1/kv_connector/nixl_integration/config_sweep_spec_decode_test.sh
|
||||
|
||||
- label: MultiConnector (Nixl+Offloading) PD edge cases (2 GPUs)
|
||||
@@ -102,5 +102,5 @@ steps:
|
||||
- vllm/distributed/kv_transfer/kv_connector/v1/offloading/
|
||||
- tests/v1/kv_connector/nixl_integration/
|
||||
commands:
|
||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||
- bash v1/kv_connector/nixl_integration/run_multi_connector_edge_case_test.sh
|
||||
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
|
||||
- bash v1/kv_connector/nixl_integration/run_multi_connector_edge_case_test.sh
|
||||
|
||||
@@ -86,7 +86,7 @@ steps:
|
||||
- tests/v1/metrics
|
||||
- tests/entrypoints/openai/correctness/test_lmeval.py
|
||||
commands:
|
||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
# split the test to avoid interference
|
||||
- pytest -v -s -m 'not cpu_test' v1/core
|
||||
|
||||
@@ -9,7 +9,7 @@ ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
|
||||
ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
|
||||
ARG FA_BRANCH="0e60e394"
|
||||
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
||||
ARG AITER_BRANCH="v0.1.13"
|
||||
ARG AITER_BRANCH="v0.1.13.post1"
|
||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||
ARG MORI_BRANCH="v1.1.0"
|
||||
ARG MORI_REPO="https://github.com/ROCm/mori.git"
|
||||
|
||||
@@ -55,6 +55,15 @@ def test_dynamic_shapes_compilation(
|
||||
evaluate_guards,
|
||||
):
|
||||
"""Test that all dynamic shapes types compile successfully"""
|
||||
if shapes_type == DynamicShapesType.UNBACKED and not is_torch_equal_or_newer(
|
||||
"2.11.0"
|
||||
):
|
||||
# NOTE[ROCm]: shape_id (used by Qwen2/Llama to relate input dims) only
|
||||
# landed in torch 2.11, but the ROCm CI still runs torch 2.10.x. On
|
||||
# older torch there's no way to express it, so unbacked shapes go
|
||||
# data-dependent and compilation blows up -- nothing to test.
|
||||
pytest.skip("unbacked dynamic shapes with shape_id require torch>=2.11")
|
||||
|
||||
if evaluate_guards and shapes_type == DynamicShapesType.UNBACKED:
|
||||
pytest.skip("unbacked dynamic shapes do not add guards")
|
||||
|
||||
|
||||
@@ -6,9 +6,13 @@ from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from vllm.config.multimodal import MultiModalConfig
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
BatchChatCompletionRequest,
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.engine.protocol import GenerationError
|
||||
from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
||||
@@ -444,3 +448,45 @@ def test_json_schema_response_format_missing_schema():
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
response_format={"type": "json_schema"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("format_value", [None, {}])
|
||||
def test_structural_tag_response_format_invalid(format_value):
|
||||
"""Malformed structural tags should be rejected during request validation."""
|
||||
with pytest.raises(
|
||||
ValidationError,
|
||||
match="Invalid response_format structural_tag",
|
||||
):
|
||||
ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
response_format={"type": "structural_tag", "format": format_value},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("format_value", [None, {}])
|
||||
def test_batch_structural_tag_response_format_invalid(format_value):
|
||||
"""Batch chat should reject malformed structural tags at request parsing."""
|
||||
with pytest.raises(
|
||||
ValidationError,
|
||||
match="Invalid response_format structural_tag",
|
||||
):
|
||||
BatchChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[[{"role": "user", "content": "hello"}]],
|
||||
response_format={"type": "structural_tag", "format": format_value},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("structural_tag", ["not json", ""])
|
||||
def test_structured_outputs_structural_tag_invalid(structural_tag):
|
||||
"""Malformed direct structured_outputs structural tags should be rejected."""
|
||||
with pytest.raises(
|
||||
ValidationError,
|
||||
match="Invalid structured_outputs structural_tag",
|
||||
):
|
||||
ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
structured_outputs={"structural_tag": structural_tag},
|
||||
)
|
||||
|
||||
@@ -1935,8 +1935,10 @@ async def test_streaming_n_gt1_independent_tool_parsers():
|
||||
finished=True,
|
||||
)
|
||||
|
||||
# Collect tool-call deltas per choice from the SSE stream.
|
||||
# Collect tool-call deltas and finish_reasons per choice from the SSE
|
||||
# stream.
|
||||
tc_deltas_by_choice: dict[int, list[dict]] = {i: [] for i in range(num_choices)}
|
||||
finish_reasons_by_choice: dict[int, list[str]] = {i: [] for i in range(num_choices)}
|
||||
async for chunk_str in serving_chat.chat_completion_stream_generator(
|
||||
request=request,
|
||||
result_generator=result_generator(),
|
||||
@@ -1959,6 +1961,8 @@ async def test_streaming_n_gt1_independent_tool_parsers():
|
||||
if delta.get("tool_calls"):
|
||||
for tc in delta["tool_calls"]:
|
||||
tc_deltas_by_choice[idx].append(tc)
|
||||
if choice.get("finish_reason") is not None:
|
||||
finish_reasons_by_choice[idx].append(choice["finish_reason"])
|
||||
|
||||
# Both choices must independently produce the correct tool call.
|
||||
for choice_idx in range(num_choices):
|
||||
@@ -1984,141 +1988,11 @@ async def test_streaming_n_gt1_independent_tool_parsers():
|
||||
f"Choice {choice_idx}: expected {{'city': 'Tokyo'}}, got {parsed_args}"
|
||||
)
|
||||
|
||||
|
||||
class TestCreateRemainingArgsDelta:
|
||||
"""Tests for _create_remaining_args_delta helper function.
|
||||
|
||||
This helper is used when streaming tool calls to preserve id/type/name
|
||||
fields in the finish chunk, which would otherwise be lost.
|
||||
"""
|
||||
|
||||
def test_preserves_id_type_name(self):
|
||||
"""Test that id, type, and name are preserved from original delta."""
|
||||
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
reasons = finish_reasons_by_choice[choice_idx]
|
||||
assert len(reasons) == 1, (
|
||||
f"Choice {choice_idx}: expected exactly 1 finish_reason, got {reasons}"
|
||||
)
|
||||
|
||||
original_delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=0,
|
||||
id="call_abc123",
|
||||
type="function",
|
||||
function=DeltaFunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"location": "Paris"}',
|
||||
),
|
||||
)
|
||||
]
|
||||
assert reasons[0] == "tool_calls", (
|
||||
f"Choice {choice_idx}: expected finish_reason='tool_calls', "
|
||||
f"got '{reasons[0]}'"
|
||||
)
|
||||
|
||||
result = OpenAIServingChat._create_remaining_args_delta(
|
||||
original_delta, '", "unit": "celsius"}', 0
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tc = result.tool_calls[0]
|
||||
assert tc.index == 0
|
||||
assert tc.id == "call_abc123"
|
||||
assert tc.type == "function"
|
||||
assert tc.function.name == "get_weather"
|
||||
assert tc.function.arguments == '", "unit": "celsius"}'
|
||||
|
||||
def test_matches_by_index(self):
|
||||
"""Test that the correct tool call is matched by index."""
|
||||
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
)
|
||||
|
||||
original_delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=0,
|
||||
id="call_first",
|
||||
type="function",
|
||||
function=DeltaFunctionCall(name="func_a", arguments="{}"),
|
||||
),
|
||||
DeltaToolCall(
|
||||
index=1,
|
||||
id="call_second",
|
||||
type="function",
|
||||
function=DeltaFunctionCall(name="func_b", arguments="{}"),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
result = OpenAIServingChat._create_remaining_args_delta(
|
||||
original_delta, '{"extra": true}', 1
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tc = result.tool_calls[0]
|
||||
assert tc.index == 1
|
||||
assert tc.id == "call_second"
|
||||
assert tc.function.name == "func_b"
|
||||
|
||||
def test_no_matching_tool_call(self):
|
||||
"""Test graceful handling when no matching tool call is found."""
|
||||
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
)
|
||||
|
||||
original_delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=0,
|
||||
id="call_zero",
|
||||
type="function",
|
||||
function=DeltaFunctionCall(name="func", arguments="{}"),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = OpenAIServingChat._create_remaining_args_delta(
|
||||
original_delta, '{"arg": 1}', 5
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tc = result.tool_calls[0]
|
||||
assert tc.index == 5
|
||||
assert tc.id is None
|
||||
assert tc.type is None
|
||||
assert tc.function.name is None
|
||||
assert tc.function.arguments == '{"arg": 1}'
|
||||
|
||||
def test_function_is_none(self):
|
||||
"""Test handling when original tool call has no function."""
|
||||
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.engine.protocol import DeltaMessage, DeltaToolCall
|
||||
|
||||
original_delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=0,
|
||||
id="call_nofunc",
|
||||
type="function",
|
||||
function=None,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = OpenAIServingChat._create_remaining_args_delta(
|
||||
original_delta, '{"data": "value"}', 0
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tc = result.tool_calls[0]
|
||||
assert tc.index == 0
|
||||
assert tc.id == "call_nofunc"
|
||||
assert tc.type == "function"
|
||||
assert tc.function.name is None
|
||||
assert tc.function.arguments == '{"data": "value"}'
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from vllm.config.multimodal import MultiModalConfig
|
||||
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
|
||||
@@ -302,6 +303,36 @@ def test_json_schema_response_format_missing_schema():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("format_value", [None, {}])
|
||||
def test_structural_tag_response_format_invalid(format_value):
|
||||
"""Malformed structural tags should be rejected during request validation."""
|
||||
with pytest.raises(
|
||||
ValidationError,
|
||||
match="Invalid response_format structural_tag",
|
||||
):
|
||||
CompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
prompt="Test prompt",
|
||||
max_tokens=10,
|
||||
response_format={"type": "structural_tag", "format": format_value},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("structural_tag", ["not json", ""])
|
||||
def test_structured_outputs_structural_tag_invalid(structural_tag):
|
||||
"""Malformed direct structured_outputs structural tags should be rejected."""
|
||||
with pytest.raises(
|
||||
ValidationError,
|
||||
match="Invalid structured_outputs structural_tag",
|
||||
):
|
||||
CompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
prompt="Test prompt",
|
||||
max_tokens=10,
|
||||
structured_outputs={"structural_tag": structural_tag},
|
||||
)
|
||||
|
||||
|
||||
def test_negative_prompt_token_ids_nested():
|
||||
"""Negative token IDs in prompt (nested list) should raise validation error."""
|
||||
with pytest.raises(Exception, match="greater than or equal to 0"):
|
||||
|
||||
@@ -235,3 +235,86 @@ def test_parse_delta_reasoning_only_thinking_disabled(tokenizer, request_obj):
|
||||
assert "Hello" in content
|
||||
assert "assist" in content
|
||||
assert len(tool_calls) == 0
|
||||
|
||||
|
||||
def test_parse_delta_finished_no_flush_without_tool_call_delta(tokenizer, request_obj):
|
||||
"""When finished=True but the final parse_delta produces no
|
||||
tool-call delta, unstreamed args are not flushed."""
|
||||
parser = make_parser(tokenizer, reasoning=False, tool=True)
|
||||
|
||||
results = stream_text(
|
||||
parser, tokenizer, MODEL_OUTPUT, request_obj, prompt_token_ids=[]
|
||||
)
|
||||
_, _, tool_calls = collect_fields(results)
|
||||
assert len(tool_calls) > 0
|
||||
|
||||
streamed = parser._tool_parser.streamed_args_for_tool[0]
|
||||
assert len(streamed) > 5
|
||||
parser._tool_parser.streamed_args_for_tool[0] = streamed[:-5]
|
||||
|
||||
# Prevent normal extraction from catching the gap — without a
|
||||
# tool-call delta to merge into, the flush is skipped.
|
||||
parser._tool_parser.extract_tool_calls_streaming = lambda *a, **kw: None
|
||||
|
||||
flush_result = parser.parse_delta("", [], request_obj, finished=True)
|
||||
assert flush_result is None or flush_result.tool_calls is None
|
||||
|
||||
|
||||
def test_parse_delta_finished_no_extra_args_when_fully_streamed(tokenizer, request_obj):
|
||||
"""When all args have been streamed, finished=True must not
|
||||
produce extra or duplicate arguments."""
|
||||
parser = make_parser(tokenizer, reasoning=False, tool=True)
|
||||
results = stream_text(
|
||||
parser, tokenizer, MODEL_OUTPUT, request_obj, prompt_token_ids=[]
|
||||
)
|
||||
_, _, tool_calls = collect_fields(results)
|
||||
|
||||
assert len(tool_calls) > 0
|
||||
assert tool_calls[0].function.name == "get_weather"
|
||||
tool_args = "".join(
|
||||
tc.function.arguments for tc in tool_calls if tc.function.arguments
|
||||
)
|
||||
assert json.loads(tool_args) == {"city": "Dallas"}
|
||||
|
||||
flush_result = parser.parse_delta("", [], request_obj, finished=True)
|
||||
assert flush_result is None or flush_result.tool_calls is None
|
||||
|
||||
|
||||
def test_parse_delta_finished_appends_remaining_args(tokenizer, request_obj):
|
||||
"""When finished=True and the tool parser has unstreamed args,
|
||||
parse_delta appends the remaining arguments to the tool-call delta."""
|
||||
parser = make_parser(tokenizer, reasoning=False, tool=True)
|
||||
token_ids = tokenizer.encode(MODEL_OUTPUT, add_special_tokens=False)
|
||||
|
||||
remainder = ',"unit":"celsius"}'
|
||||
prompt_ids: list[int] | None = []
|
||||
results: list[DeltaMessage | None] = []
|
||||
for i, tid in enumerate(token_ids):
|
||||
prev = results[-1] if results else None
|
||||
prev_had_args = (
|
||||
prev
|
||||
and prev.tool_calls
|
||||
and any(tc.function and tc.function.arguments for tc in prev.tool_calls)
|
||||
)
|
||||
|
||||
if prev_had_args:
|
||||
parser._tool_parser.get_remaining_unstreamed_args = lambda: remainder
|
||||
|
||||
result = parser.parse_delta(
|
||||
tokenizer.decode([tid]),
|
||||
[tid],
|
||||
request_obj,
|
||||
prompt_token_ids=prompt_ids,
|
||||
finished=prev_had_args,
|
||||
)
|
||||
prompt_ids = None
|
||||
results.append(result)
|
||||
|
||||
if prev_had_args:
|
||||
break
|
||||
|
||||
_, _, tool_calls = collect_fields(results)
|
||||
tool_args = "".join(
|
||||
tc.function.arguments for tc in tool_calls if tc.function.arguments
|
||||
)
|
||||
assert tool_args.endswith(remainder)
|
||||
|
||||
@@ -1382,7 +1382,20 @@ def test_adjust_request_non_mistral_tokenizer(
|
||||
[
|
||||
{"regex": r"\d+"},
|
||||
{"choice": ["a", "b"]},
|
||||
{"structural_tag": '{"key": "value"}'},
|
||||
{
|
||||
"structural_tag": json.dumps(
|
||||
{
|
||||
"structures": [
|
||||
{
|
||||
"begin": "<tool>",
|
||||
"schema": {"type": "object"},
|
||||
"end": "</tool>",
|
||||
}
|
||||
],
|
||||
"triggers": ["<tool>"],
|
||||
}
|
||||
)
|
||||
},
|
||||
{"grammar": "start: 'hello'"},
|
||||
],
|
||||
ids=["regex", "choice", "structural_tag", "grammar"],
|
||||
@@ -1404,7 +1417,18 @@ def test_adjust_request_unsupported_response_format(
|
||||
) -> None:
|
||||
request = _make_request(
|
||||
response_format=StructuralTagResponseFormat(
|
||||
type="structural_tag", format={"some": "config"}
|
||||
type="structural_tag",
|
||||
format={
|
||||
"type": "triggered_tags",
|
||||
"tags": [
|
||||
{
|
||||
"begin": "<tool>",
|
||||
"content": {"type": "any_text"},
|
||||
"end": "</tool>",
|
||||
}
|
||||
],
|
||||
"triggers": ["<tool>"],
|
||||
},
|
||||
),
|
||||
)
|
||||
result = mistral_tool_parser.adjust_request(request)
|
||||
|
||||
@@ -1305,3 +1305,79 @@ def test_swa_alignment_skip(request_runner, async_scheduling: bool):
|
||||
(1, 7),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("async_scheduling", [True, False])
|
||||
def test_stale_sliding_window_block_after_prepare_store_failure(
|
||||
request_runner, async_scheduling: bool
|
||||
):
|
||||
"""Regression test: when prepare_store fails (returns None), offloading is
|
||||
delayed. Meanwhile, sliding window blocks get freed and reallocated to the
|
||||
same request. On retry, the stale block_id must be detected and skipped.
|
||||
|
||||
Without the fix, the stale block_id would either:
|
||||
- Cause a KeyError in _remove_pending_job (duplicate in
|
||||
_block_id_to_pending_jobs)
|
||||
- Silently offload wrong data under a wrong key
|
||||
"""
|
||||
block_size = 4
|
||||
# sliding_window = 8 -> window of 2 blocks
|
||||
sliding_window = 8
|
||||
# Use a tight GPU block budget so freed sliding window blocks are
|
||||
# immediately reused by the same request's new allocations.
|
||||
num_gpu_blocks = 4
|
||||
|
||||
kv_cache_groups = [
|
||||
KVCacheGroupSpec(
|
||||
["layer0"],
|
||||
SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=sliding_window,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
runner = request_runner(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
async_scheduling=async_scheduling,
|
||||
kv_cache_groups=kv_cache_groups,
|
||||
)
|
||||
|
||||
# Request with 3 blocks of prompt. Window = 2 blocks, so block 0 is
|
||||
# outside the window but won't be freed until the next allocate_slots.
|
||||
runner.new_request(token_ids=[0] * block_size * 3)
|
||||
|
||||
# First step: prepare_store FAILS -> offloading delayed.
|
||||
# next_stored_block_idx stays at 0, block_ids[0] still holds the
|
||||
# original block_id for position 0.
|
||||
runner.manager.prepare_store.side_effect = lambda keys, req_context: None
|
||||
runner.run(decoded_tokens=[0])
|
||||
runner.manager.prepare_store.assert_called()
|
||||
|
||||
# Second step: decode more tokens -> block 3 allocated.
|
||||
# allocate_slots calls remove_skipped_blocks which frees block 0
|
||||
# (it's now outside the sliding window). With num_gpu_blocks=4,
|
||||
# the freed block is immediately reused for the new allocation.
|
||||
# prepare_store still fails so offloading is still delayed.
|
||||
runner.manager.prepare_store.side_effect = lambda keys, req_context: None
|
||||
runner.run(decoded_tokens=[0] * block_size)
|
||||
|
||||
# Now prepare_store succeeds.
|
||||
# Without the fix, the request would try to offload the stale block_id
|
||||
# at position 0 (now reused at position 3), causing a duplicate in
|
||||
# sliding_window_block_ids and eventually a KeyError.
|
||||
runner.manager.prepare_store.side_effect = lambda keys, req_context: (
|
||||
generate_store_output(keys)
|
||||
)
|
||||
# block_ids=[0, ?, 3, 1]: positions 0 and 1 are zeroed (stale blocks that
|
||||
# were freed by the sliding window and reallocated). Only blocks at
|
||||
# positions 2 and 3 (request offsets 2, 3) are stored.
|
||||
runner.run(
|
||||
decoded_tokens=[EOS_TOKEN_ID],
|
||||
expected_stored=(2, 3),
|
||||
expected_flushed=(2, 3) if not async_scheduling else (),
|
||||
)
|
||||
|
||||
@@ -461,6 +461,35 @@ def test_store_sending_thread_only_skips_on_no_available_handle():
|
||||
assert store.batch_put_from_multi_buffers.call_count == 2
|
||||
|
||||
|
||||
def test_store_sending_thread_releases_pin_on_batch_is_exist_failure():
|
||||
# `batch_is_exist` raising must still decrement `stored_requests` so the
|
||||
# scheduler can drop `delay_free_blocks` and release the pinned GPU blocks.
|
||||
store = MagicMock()
|
||||
store.batch_is_exist.side_effect = RuntimeError("mooncake down")
|
||||
thread = _make_store_sending_thread(store)
|
||||
|
||||
thread.add_stored_request("req-a")
|
||||
with pytest.raises(RuntimeError):
|
||||
thread._handle_request(_make_store_req("req-a", [b"a0", b"a1"]))
|
||||
|
||||
assert thread.stored_requests["req-a"] == 0
|
||||
store.batch_put_from_multi_buffers.assert_not_called()
|
||||
|
||||
|
||||
def test_store_sending_thread_releases_pin_on_batch_put_failure():
|
||||
# `batch_put_from_multi_buffers` raising is logged (not re-raised), and the
|
||||
# pin must still be released through the finally block.
|
||||
store = MagicMock()
|
||||
store.batch_is_exist.return_value = [0, 0]
|
||||
store.batch_put_from_multi_buffers.side_effect = RuntimeError("rdma error")
|
||||
thread = _make_store_sending_thread(store)
|
||||
|
||||
thread.add_stored_request("req-a")
|
||||
thread._handle_request(_make_store_req("req-a", [b"a0", b"a1"]))
|
||||
|
||||
assert thread.stored_requests["req-a"] == 0
|
||||
|
||||
|
||||
def test_store_recving_thread_reports_failed_block_ids():
|
||||
store = MagicMock()
|
||||
store.batch_get_into_multi_buffers.return_value = [256, -5, -7]
|
||||
|
||||
+10
-1
@@ -338,7 +338,16 @@ def _xpu_mxfp8_quantize_impl(
|
||||
shape = x.shape[:-1] + (x.shape[-1] // MXFP8_BLOCK_SIZE,)
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
|
||||
torch.ops._C.per_token_group_fp8_quant(
|
||||
x, x_q, x_s, MXFP8_BLOCK_SIZE, eps, fp8_min, fp8_max, True
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
MXFP8_BLOCK_SIZE,
|
||||
eps,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
True,
|
||||
False,
|
||||
False, # dummy_is_scale_transposed, dummy_is_tma_aligned
|
||||
)
|
||||
x_s = x_s.to(torch.float8_e8m0fnu)
|
||||
return x_q, x_s
|
||||
|
||||
@@ -517,187 +517,190 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
||||
if req_id not in self.stored_requests:
|
||||
self.request_queue.task_done()
|
||||
return
|
||||
if token_len == 0:
|
||||
self.dec_stored_request(req_id)
|
||||
self.request_queue.task_done()
|
||||
return
|
||||
if self._should_skip_request(req_id):
|
||||
logger.debug(
|
||||
"Skipping Mooncake store for request %s while CPU/disk offloading "
|
||||
"is under pressure",
|
||||
req_id,
|
||||
)
|
||||
self.dec_stored_request(req_id)
|
||||
self.request_queue.task_done()
|
||||
return
|
||||
|
||||
# Within each lcm region only per-spec relevant chunks are loaded
|
||||
# (e.g., SWA or linear attn), so mask out irrelevant chunks
|
||||
store_masks = self.coord.store_mask(token_len)
|
||||
starts: list[int] = []
|
||||
ends: list[int] = []
|
||||
keys: list[str] = []
|
||||
block_hashes: list[BlockHash] = []
|
||||
group_indices: list[int] = []
|
||||
for g_idx, db in enumerate(self.token_databases):
|
||||
mask = store_masks[g_idx]
|
||||
for chunk_idx, (start, end, key) in enumerate(
|
||||
db.process_tokens(token_len, req_meta.block_hashes)
|
||||
):
|
||||
if chunk_idx >= len(mask) or not mask[chunk_idx]:
|
||||
continue
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
keys.append(key.to_string())
|
||||
block_hashes.append(req_meta.block_hashes[chunk_idx])
|
||||
group_indices.append(g_idx)
|
||||
|
||||
# Apply put_step striding for TP
|
||||
sl = slice(self.tp_rank % self.put_step, None, self.put_step)
|
||||
starts = starts[sl]
|
||||
ends = ends[sl]
|
||||
keys = keys[sl]
|
||||
block_hashes = block_hashes[sl]
|
||||
group_indices = group_indices[sl]
|
||||
|
||||
if not keys:
|
||||
self.dec_stored_request(req_id)
|
||||
return
|
||||
|
||||
# Check which blocks already exist (dedup)
|
||||
save_exists_start = time.perf_counter()
|
||||
# Decrement the in-flight counter and signal task_done() in `finally`
|
||||
# so the scheduler can release the GPU blocks it pinned for this
|
||||
# request (via `delay_free_blocks`) even when the store path raises.
|
||||
try:
|
||||
exists_states = self.store.batch_is_exist(keys)
|
||||
except Exception:
|
||||
if token_len == 0:
|
||||
return
|
||||
if self._should_skip_request(req_id):
|
||||
logger.debug(
|
||||
"Skipping Mooncake store for request %s while CPU/disk "
|
||||
"offloading is under pressure",
|
||||
req_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Within each lcm region only per-spec relevant chunks are loaded
|
||||
# (e.g., SWA or linear attn), so mask out irrelevant chunks
|
||||
store_masks = self.coord.store_mask(token_len)
|
||||
starts: list[int] = []
|
||||
ends: list[int] = []
|
||||
keys: list[str] = []
|
||||
block_hashes: list[BlockHash] = []
|
||||
group_indices: list[int] = []
|
||||
for g_idx, db in enumerate(self.token_databases):
|
||||
mask = store_masks[g_idx]
|
||||
for chunk_idx, (start, end, key) in enumerate(
|
||||
db.process_tokens(token_len, req_meta.block_hashes)
|
||||
):
|
||||
if chunk_idx >= len(mask) or not mask[chunk_idx]:
|
||||
continue
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
keys.append(key.to_string())
|
||||
block_hashes.append(req_meta.block_hashes[chunk_idx])
|
||||
group_indices.append(g_idx)
|
||||
|
||||
# Apply put_step striding for TP
|
||||
sl = slice(self.tp_rank % self.put_step, None, self.put_step)
|
||||
starts = starts[sl]
|
||||
ends = ends[sl]
|
||||
keys = keys[sl]
|
||||
block_hashes = block_hashes[sl]
|
||||
group_indices = group_indices[sl]
|
||||
|
||||
if not keys:
|
||||
return
|
||||
|
||||
# Check which blocks already exist (dedup)
|
||||
save_exists_start = time.perf_counter()
|
||||
try:
|
||||
exists_states = self.store.batch_is_exist(keys)
|
||||
except Exception:
|
||||
self._record_operation(
|
||||
"save_exists",
|
||||
save_exists_start,
|
||||
len(keys),
|
||||
status="error",
|
||||
num_failed_keys=len(keys),
|
||||
)
|
||||
raise
|
||||
self._record_operation(
|
||||
"save_exists",
|
||||
save_exists_start,
|
||||
len(keys),
|
||||
status="error",
|
||||
num_failed_keys=len(keys),
|
||||
)
|
||||
raise
|
||||
self._record_operation(
|
||||
"save_exists",
|
||||
save_exists_start,
|
||||
len(keys),
|
||||
)
|
||||
missing_indices = [i for i, exists in enumerate(exists_states) if exists != 1]
|
||||
missing_indices = [
|
||||
i for i, exists in enumerate(exists_states) if exists != 1
|
||||
]
|
||||
|
||||
if not missing_indices:
|
||||
self.dec_stored_request(req_id)
|
||||
return
|
||||
if not missing_indices:
|
||||
return
|
||||
|
||||
starts = [starts[i] for i in missing_indices]
|
||||
ends = [ends[i] for i in missing_indices]
|
||||
keys = [keys[i] for i in missing_indices]
|
||||
block_hashes = [block_hashes[i] for i in missing_indices]
|
||||
group_indices = [group_indices[i] for i in missing_indices]
|
||||
starts = [starts[i] for i in missing_indices]
|
||||
ends = [ends[i] for i in missing_indices]
|
||||
keys = [keys[i] for i in missing_indices]
|
||||
block_hashes = [block_hashes[i] for i in missing_indices]
|
||||
group_indices = [group_indices[i] for i in missing_indices]
|
||||
|
||||
logger.debug(
|
||||
"Storing KV cache for %d blocks (groups=%s) for request %s",
|
||||
len(keys),
|
||||
set(group_indices),
|
||||
req_id,
|
||||
)
|
||||
|
||||
addrs: list[list[int]] = []
|
||||
sizes: list[list[int]] = []
|
||||
stored_events: list[BlockStored] = []
|
||||
# parent_block_hash chains live within a group, not across.
|
||||
prev_key_per_group: dict[int, Any] = {}
|
||||
new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes]
|
||||
|
||||
for idx, (s, e, g_idx) in enumerate(
|
||||
zip(starts, ends, group_indices, strict=True)
|
||||
):
|
||||
db = self.token_databases[g_idx]
|
||||
addr, size, _ = db.prepare_value(s, e, block_ids_per_group[g_idx])
|
||||
addrs.append(addr)
|
||||
sizes.append(size)
|
||||
|
||||
if self.enable_kv_event:
|
||||
token_ids = (
|
||||
req_meta.token_ids[s:e] if req_meta.token_ids is not None else None
|
||||
)
|
||||
stored_event = BlockStored(
|
||||
block_hashes=[new_block_hashes[idx]],
|
||||
parent_block_hash=prev_key_per_group.get(g_idx),
|
||||
token_ids=token_ids,
|
||||
block_size=req_meta.original_block_size,
|
||||
lora_id=None,
|
||||
medium="cpu",
|
||||
lora_name=None,
|
||||
)
|
||||
stored_events.append(stored_event)
|
||||
prev_key_per_group[g_idx] = new_block_hashes[idx]
|
||||
|
||||
if current_event is not None:
|
||||
current_event.synchronize()
|
||||
|
||||
batch_bytes = _sum_batch_bytes(sizes)
|
||||
put_start = time.perf_counter()
|
||||
try:
|
||||
res = self.store.batch_put_from_multi_buffers(
|
||||
keys,
|
||||
addrs,
|
||||
sizes,
|
||||
self.replicate_config,
|
||||
)
|
||||
failed = [i for i, v in enumerate(res) if v < 0]
|
||||
self._record_operation(
|
||||
"save_put",
|
||||
put_start,
|
||||
logger.debug(
|
||||
"Storing KV cache for %d blocks (groups=%s) for request %s",
|
||||
len(keys),
|
||||
num_bytes=batch_bytes,
|
||||
status="partial_failure" if failed else "ok",
|
||||
num_failed_keys=len(failed),
|
||||
set(group_indices),
|
||||
req_id,
|
||||
)
|
||||
if failed:
|
||||
failed_codes = set(res[i] for i in failed)
|
||||
logger.warning(
|
||||
"batch_put failed: %d/%d keys failed "
|
||||
"(codes=%s, batch_bytes=%d, num_keys=%d), "
|
||||
"first_key=%s",
|
||||
len(failed),
|
||||
len(keys),
|
||||
failed_codes,
|
||||
batch_bytes,
|
||||
len(keys),
|
||||
keys[0] if keys else "N/A",
|
||||
)
|
||||
if (
|
||||
MOONCAKE_NO_AVAILABLE_HANDLE in failed_codes
|
||||
and not self._mark_request_skipped_for_pressure(req_id)
|
||||
):
|
||||
logger.warning(
|
||||
"Detected Mooncake CPU/disk offloading pressure "
|
||||
"(NO_AVAILABLE_HANDLE); skipping future store "
|
||||
"batches for request %s until a later store "
|
||||
"batch succeeds",
|
||||
req_id,
|
||||
|
||||
addrs: list[list[int]] = []
|
||||
sizes: list[list[int]] = []
|
||||
stored_events: list[BlockStored] = []
|
||||
# parent_block_hash chains live within a group, not across.
|
||||
prev_key_per_group: dict[int, Any] = {}
|
||||
new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes]
|
||||
|
||||
for idx, (s, e, g_idx) in enumerate(
|
||||
zip(starts, ends, group_indices, strict=True)
|
||||
):
|
||||
db = self.token_databases[g_idx]
|
||||
addr, size, _ = db.prepare_value(s, e, block_ids_per_group[g_idx])
|
||||
addrs.append(addr)
|
||||
sizes.append(size)
|
||||
|
||||
if self.enable_kv_event:
|
||||
token_ids = (
|
||||
req_meta.token_ids[s:e]
|
||||
if req_meta.token_ids is not None
|
||||
else None
|
||||
)
|
||||
elif self._clear_store_pressure():
|
||||
logger.info(
|
||||
"Mooncake CPU/disk offloading pressure cleared after a "
|
||||
"successful store batch"
|
||||
stored_event = BlockStored(
|
||||
block_hashes=[new_block_hashes[idx]],
|
||||
parent_block_hash=prev_key_per_group.get(g_idx),
|
||||
token_ids=token_ids,
|
||||
block_size=req_meta.original_block_size,
|
||||
lora_id=None,
|
||||
medium="cpu",
|
||||
lora_name=None,
|
||||
)
|
||||
stored_events.append(stored_event)
|
||||
prev_key_per_group[g_idx] = new_block_hashes[idx]
|
||||
|
||||
if current_event is not None:
|
||||
current_event.synchronize()
|
||||
|
||||
batch_bytes = _sum_batch_bytes(sizes)
|
||||
put_start = time.perf_counter()
|
||||
try:
|
||||
res = self.store.batch_put_from_multi_buffers(
|
||||
keys,
|
||||
addrs,
|
||||
sizes,
|
||||
self.replicate_config,
|
||||
)
|
||||
except Exception as e:
|
||||
self._record_operation(
|
||||
"save_put",
|
||||
put_start,
|
||||
len(keys),
|
||||
num_bytes=batch_bytes,
|
||||
status="error",
|
||||
num_failed_keys=len(keys),
|
||||
)
|
||||
logger.error("Failed to put key %s, error: %s", keys, e)
|
||||
failed = [i for i, v in enumerate(res) if v < 0]
|
||||
self._record_operation(
|
||||
"save_put",
|
||||
put_start,
|
||||
len(keys),
|
||||
num_bytes=batch_bytes,
|
||||
status="partial_failure" if failed else "ok",
|
||||
num_failed_keys=len(failed),
|
||||
)
|
||||
if failed:
|
||||
failed_codes = set(res[i] for i in failed)
|
||||
logger.warning(
|
||||
"batch_put failed: %d/%d keys failed "
|
||||
"(codes=%s, batch_bytes=%d, num_keys=%d), "
|
||||
"first_key=%s",
|
||||
len(failed),
|
||||
len(keys),
|
||||
failed_codes,
|
||||
batch_bytes,
|
||||
len(keys),
|
||||
keys[0] if keys else "N/A",
|
||||
)
|
||||
if (
|
||||
MOONCAKE_NO_AVAILABLE_HANDLE in failed_codes
|
||||
and not self._mark_request_skipped_for_pressure(req_id)
|
||||
):
|
||||
logger.warning(
|
||||
"Detected Mooncake CPU/disk offloading pressure "
|
||||
"(NO_AVAILABLE_HANDLE); skipping future store "
|
||||
"batches for request %s until a later store "
|
||||
"batch succeeds",
|
||||
req_id,
|
||||
)
|
||||
elif self._clear_store_pressure():
|
||||
logger.info(
|
||||
"Mooncake CPU/disk offloading pressure cleared after a "
|
||||
"successful store batch"
|
||||
)
|
||||
except Exception as e:
|
||||
self._record_operation(
|
||||
"save_put",
|
||||
put_start,
|
||||
len(keys),
|
||||
num_bytes=batch_bytes,
|
||||
status="error",
|
||||
num_failed_keys=len(keys),
|
||||
)
|
||||
logger.error("Failed to put key %s, error: %s", keys, e)
|
||||
|
||||
if self.enable_kv_event and stored_events:
|
||||
self.update_kv_event(stored_events)
|
||||
|
||||
self.dec_stored_request(req_id)
|
||||
self.request_queue.task_done()
|
||||
if self.enable_kv_event and stored_events:
|
||||
self.update_kv_event(stored_events)
|
||||
finally:
|
||||
self.dec_stored_request(req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreRecvingThread(KVTransferThread):
|
||||
|
||||
@@ -286,6 +286,8 @@ class OffloadingConnectorScheduler:
|
||||
self._req_status: dict[ReqId, RequestOffloadState] = {}
|
||||
self._current_batch_load_jobs: dict[int, TransferJob] = {}
|
||||
self._current_batch_jobs_to_flush: set[int] = set()
|
||||
# GPU block IDs allocated in the current engine step
|
||||
self._current_batch_allocated_block_ids: set[int] = set()
|
||||
# if GPU prefix caching is enabled,
|
||||
# track loaded blocks to avoid redundant loads
|
||||
self._blocks_being_loaded: set[OffloadKey] | None = (
|
||||
@@ -589,6 +591,10 @@ class OffloadingConnectorScheduler:
|
||||
req_status.group_states,
|
||||
blocks.blocks,
|
||||
):
|
||||
self._current_batch_allocated_block_ids.update(
|
||||
block.block_id for block in group_blocks if block.block_id != 0
|
||||
)
|
||||
|
||||
gpu_block_size = group_config.gpu_block_size
|
||||
offloaded_block_size = group_config.offloaded_block_size
|
||||
offload_keys = group_state.offload_keys
|
||||
@@ -638,17 +644,6 @@ class OffloadingConnectorScheduler:
|
||||
if req_status.offloading_context.policy == OffloadPolicy.BLOCK_LEVEL:
|
||||
group_state.next_stored_block_idx = num_blocks
|
||||
|
||||
# Fence dst blocks against finished-request pending stores.
|
||||
if (
|
||||
self._block_id_to_pending_jobs
|
||||
and not self._block_id_to_pending_jobs.keys().isdisjoint(dst_block_ids)
|
||||
):
|
||||
self._current_batch_jobs_to_flush.update(
|
||||
jid
|
||||
for bid in dst_block_ids
|
||||
for jid in self._block_id_to_pending_jobs.get(bid, ())
|
||||
)
|
||||
|
||||
src_spec = self.manager.prepare_load(keys_to_load, req_status.req_context)
|
||||
dst_spec = GPULoadStoreSpec(
|
||||
dst_block_ids, group_sizes=group_sizes, block_indices=block_indices
|
||||
@@ -672,37 +667,68 @@ class OffloadingConnectorScheduler:
|
||||
if self._blocks_being_loaded is not None:
|
||||
self._blocks_being_loaded.update(keys_to_load)
|
||||
|
||||
def _build_store_jobs(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> dict[int, TransferJob]:
|
||||
block_size_factor = self.config.block_size_factor
|
||||
store_jobs: dict[int, TransferJob] = {}
|
||||
# iterate over both new and cached requests
|
||||
def _update_req_states(self, scheduler_output: SchedulerOutput) -> None:
|
||||
"""
|
||||
Update request states from the Scheduler's output.
|
||||
"""
|
||||
|
||||
# new_block_ids_end[req_id][i] = end of pre-existing block_ids for
|
||||
# the i-th sliding window group (before this step's extend).
|
||||
# Used to detect sliding window blocks that got re-allocated.
|
||||
new_block_ids_end: dict[str, tuple[int, ...]] = {}
|
||||
|
||||
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
|
||||
req_status = self._req_status[req_id]
|
||||
req_status.update_offload_keys()
|
||||
req = req_status.req
|
||||
|
||||
if preempted:
|
||||
for group_state in req_status.group_states:
|
||||
group_state.block_ids.clear()
|
||||
|
||||
if new_block_id_groups:
|
||||
if self._sliding_window_groups:
|
||||
new_block_ids_end[req_id] = tuple(
|
||||
len(req_status.group_states[grp_idx].block_ids)
|
||||
for grp_idx in self._sliding_window_groups
|
||||
)
|
||||
req_status.update_block_id_groups(new_block_id_groups)
|
||||
# Fence new blocks against in-flight stores.
|
||||
if self._block_id_to_pending_jobs:
|
||||
new_blocks_flat = [
|
||||
bid for new_blocks in new_block_id_groups for bid in new_blocks
|
||||
]
|
||||
if not self._block_id_to_pending_jobs.keys().isdisjoint(
|
||||
new_blocks_flat
|
||||
):
|
||||
self._current_batch_jobs_to_flush.update(
|
||||
jid
|
||||
for bid in new_blocks_flat
|
||||
for jid in self._block_id_to_pending_jobs.get(bid, ())
|
||||
)
|
||||
for new_blocks in new_block_id_groups:
|
||||
for bid in new_blocks:
|
||||
if bid != 0:
|
||||
self._current_batch_allocated_block_ids.add(bid)
|
||||
|
||||
# Zero out stale block_ids in sliding window groups' pending-store
|
||||
# positions. Only sliding window groups can have stale entries (blocks
|
||||
# freed by remove_skipped_blocks then reallocated). Only positions in
|
||||
# [next_stored_block_idx * bsf, end) need checking where end is the
|
||||
# pre-extend length: earlier positions were already offloaded, later
|
||||
# ones are fresh allocations from this step.
|
||||
if self._sliding_window_groups and self._current_batch_allocated_block_ids:
|
||||
block_size_factor = self.config.block_size_factor
|
||||
for req_id, req_status in self._req_status.items():
|
||||
ends = new_block_ids_end.get(req_id)
|
||||
for i, grp_idx in enumerate(self._sliding_window_groups):
|
||||
group_state = req_status.group_states[grp_idx]
|
||||
start = group_state.next_stored_block_idx * block_size_factor
|
||||
end = ends[i] if ends is not None else len(group_state.block_ids)
|
||||
for j in range(start, end):
|
||||
if (
|
||||
group_state.block_ids[j]
|
||||
in self._current_batch_allocated_block_ids
|
||||
):
|
||||
group_state.block_ids[j] = 0
|
||||
|
||||
def _build_store_jobs(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> dict[int, TransferJob]:
|
||||
block_size_factor = self.config.block_size_factor
|
||||
store_jobs: dict[int, TransferJob] = {}
|
||||
for req_id in scheduler_output.num_scheduled_tokens:
|
||||
req_status = self._req_status.get(req_id)
|
||||
if req_status is None:
|
||||
continue
|
||||
req = req_status.req
|
||||
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
num_tokens_after_batch = req.num_computed_tokens + num_scheduled_tokens
|
||||
@@ -735,11 +761,8 @@ class OffloadingConnectorScheduler:
|
||||
# For each block to offload, take the last corresponding GPU block.
|
||||
# e.g. if block size factor is 3 and GPU block IDs are
|
||||
# 1 5 6 7 2 4 9 3 8 then we'll take blocks 6 4 8.
|
||||
# We will use these GPU blocks to determine if the block needs
|
||||
# offloading, or (if the GPU block ID is 0) this block should
|
||||
# be skipped due to sliding window attention / SSM.
|
||||
# We know that if a block is skipped, then all the previous blocks
|
||||
# are skipped as well. This is why we take the last of each block.
|
||||
# A block_id of 0 means either a sliding window / SSM skip
|
||||
# or a stale entry that was zeroed out — skip it either way.
|
||||
offload_block_ids = group_state.block_ids[
|
||||
start_block_idx * block_size_factor
|
||||
+ block_size_factor
|
||||
@@ -814,10 +837,8 @@ class OffloadingConnectorScheduler:
|
||||
for i in range(block_size_factor):
|
||||
block_id = block_ids[gpu_block_idx + i]
|
||||
if block_id == 0:
|
||||
# skipped blocks cannot appear after non-skipped blocks
|
||||
assert start_gpu_block_idx is None
|
||||
continue
|
||||
elif start_gpu_block_idx is None:
|
||||
if start_gpu_block_idx is None:
|
||||
start_gpu_block_idx = gpu_block_idx + i
|
||||
src_block_ids.append(block_id)
|
||||
num_group_blocks += 1
|
||||
@@ -875,6 +896,9 @@ class OffloadingConnectorScheduler:
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> KVConnectorMetadata:
|
||||
self._update_req_states(scheduler_output)
|
||||
|
||||
# Flush jobs for preempted requests.
|
||||
for req_id in scheduler_output.preempted_req_ids or ():
|
||||
req_status = self._req_status.get(req_id)
|
||||
if req_status is None or not req_status.transfer_jobs:
|
||||
@@ -883,6 +907,20 @@ class OffloadingConnectorScheduler:
|
||||
assert self._jobs[any_jid].is_store
|
||||
self._current_batch_jobs_to_flush.update(req_status.transfer_jobs)
|
||||
|
||||
# Flush jobs that contain re-allocated blocks.
|
||||
if (
|
||||
self._block_id_to_pending_jobs
|
||||
and not self._block_id_to_pending_jobs.keys().isdisjoint(
|
||||
self._current_batch_allocated_block_ids
|
||||
)
|
||||
):
|
||||
self._current_batch_jobs_to_flush.update(
|
||||
jid
|
||||
for bid in self._current_batch_allocated_block_ids
|
||||
if bid in self._block_id_to_pending_jobs
|
||||
for jid in self._block_id_to_pending_jobs[bid]
|
||||
)
|
||||
|
||||
# If all tracked requests are finished, flush all pending jobs
|
||||
# (both store and load) - there might not be a future scheduler
|
||||
# step to trigger their completion.
|
||||
@@ -898,6 +936,7 @@ class OffloadingConnectorScheduler:
|
||||
)
|
||||
self._current_batch_load_jobs = {}
|
||||
self._current_batch_jobs_to_flush = set()
|
||||
self._current_batch_allocated_block_ids = set()
|
||||
return meta
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
@@ -1013,6 +1052,7 @@ class OffloadingConnectorScheduler:
|
||||
# reset_cache cannot be called in the middle of a schedule step
|
||||
assert not self._current_batch_load_jobs
|
||||
assert not self._current_batch_jobs_to_flush
|
||||
assert not self._current_batch_allocated_block_ids
|
||||
|
||||
# Flush all in-flight jobs
|
||||
self._current_batch_jobs_to_flush.update(self._jobs.keys())
|
||||
|
||||
@@ -30,6 +30,8 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
StructuralTagResponseFormat,
|
||||
ToolCall,
|
||||
UsageInfo,
|
||||
validate_structural_tag_response_format,
|
||||
validate_structured_outputs_structural_tag,
|
||||
)
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
@@ -671,6 +673,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
parameter="response_format",
|
||||
)
|
||||
|
||||
if rf_type == "structural_tag":
|
||||
validate_structural_tag_response_format(response_format)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -754,6 +759,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
"You can only either use constraints for structured outputs "
|
||||
"or tools, not both.",
|
||||
)
|
||||
validate_structured_outputs_structural_tag(structured_outputs_kwargs)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -979,6 +985,16 @@ class BatchChatCompletionRequest(OpenAIBaseModel):
|
||||
"Batch chat completions do not support beam search. "
|
||||
"Please set `use_beam_search` to False."
|
||||
)
|
||||
response_format = data.get("response_format")
|
||||
rf_type = (
|
||||
response_format.get("type")
|
||||
if isinstance(response_format, dict)
|
||||
else getattr(response_format, "type", None)
|
||||
)
|
||||
if rf_type == "structural_tag":
|
||||
validate_structural_tag_response_format(response_format)
|
||||
if (structured_outputs := data.get("structured_outputs")) is not None:
|
||||
validate_structured_outputs_structural_tag(structured_outputs)
|
||||
n = data.get("n", 1)
|
||||
if n is not None and n != 1:
|
||||
raise ValueError(
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
@@ -40,9 +39,7 @@ from vllm.entrypoints.openai.chat_completion.stream_harmony import (
|
||||
extract_harmony_streaming_delta,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ErrorResponse,
|
||||
FunctionCall,
|
||||
PromptTokenUsageInfo,
|
||||
@@ -65,7 +62,7 @@ from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||
from vllm.inputs import EngineInput
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.parser import ParserManager
|
||||
from vllm.parser.abstract_parser import Parser
|
||||
from vllm.reasoning import ReasoningParser
|
||||
@@ -715,6 +712,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
delta_token_ids=as_list(output.token_ids),
|
||||
request=request,
|
||||
prompt_token_ids=res.prompt_token_ids,
|
||||
finished=output.finish_reason is not None,
|
||||
)
|
||||
if delta_message and delta_message.tool_calls:
|
||||
tools_streamed[i] = True
|
||||
@@ -805,81 +803,13 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# finish_reason='error' indicates a retryable error
|
||||
self._raise_if_error(output.finish_reason, request_id)
|
||||
|
||||
# check to make sure we haven't "forgotten" to stream
|
||||
# any tokens that were generated but previously
|
||||
# matched by partial json parsing
|
||||
# only happens if we are NOT using structured outputs
|
||||
index = 0
|
||||
auto_tools_called = False
|
||||
if tool_parser:
|
||||
auto_tools_called = len(tool_parser.prev_tool_call_arr) > 0
|
||||
index = (
|
||||
len(tool_parser.prev_tool_call_arr) - 1
|
||||
if auto_tools_called
|
||||
else 0
|
||||
)
|
||||
should_check = (
|
||||
self._should_check_for_unstreamed_tool_arg_tokens(
|
||||
delta_message, output
|
||||
)
|
||||
)
|
||||
# only check if there are any tool calls
|
||||
# detected by partial parsing
|
||||
if should_check and tool_parser and auto_tools_called:
|
||||
latest_delta_len = 0
|
||||
if (
|
||||
isinstance(
|
||||
delta_message.tool_calls[0].function,
|
||||
DeltaFunctionCall,
|
||||
)
|
||||
) and isinstance(
|
||||
delta_message.tool_calls[0].function.arguments, str
|
||||
):
|
||||
latest_delta_len = len(
|
||||
delta_message.tool_calls[0].function.arguments
|
||||
)
|
||||
|
||||
# get the expected call based on partial JSON
|
||||
# parsing which "autocompletes" the JSON.
|
||||
# Tool parsers (e.g. Qwen3Coder) store
|
||||
# arguments as a JSON string in
|
||||
# prev_tool_call_arr. Calling json.dumps()
|
||||
# on an already-serialized string would
|
||||
# double-serialize it (e.g. '{"k":1}' becomes
|
||||
# '"{\\"k\\":1}"'), which then causes the
|
||||
# replace() below to fail and append the
|
||||
# entire double-serialized string as a
|
||||
# spurious final delta.
|
||||
args = tool_parser.prev_tool_call_arr[index].get(
|
||||
"arguments", {}
|
||||
)
|
||||
if isinstance(args, str):
|
||||
expected_call = args
|
||||
else:
|
||||
expected_call = json.dumps(args, ensure_ascii=False)
|
||||
|
||||
# get what we've streamed so far for arguments
|
||||
# for the current tool
|
||||
actual_call = tool_parser.streamed_args_for_tool[index]
|
||||
if latest_delta_len > 0:
|
||||
actual_call = actual_call[:-latest_delta_len]
|
||||
|
||||
# check to see if there's anything left to stream
|
||||
remaining_call = expected_call.replace(actual_call, "", 1)
|
||||
# set that as a delta message
|
||||
delta_message = self._create_remaining_args_delta(
|
||||
delta_message, remaining_call, index
|
||||
)
|
||||
|
||||
# Send the finish response for each request.n only once
|
||||
# In OpenAI's API, when a tool is called, the
|
||||
# finish_reason is:
|
||||
# "tool_calls" for "auto" or "required" tool calls,
|
||||
# and "stop" for named tool calls.
|
||||
if (
|
||||
auto_tools_called
|
||||
or (tools_streamed[i] and not tool_choice_function_name)
|
||||
or (self.use_harmony and harmony_tools_streamed[i])
|
||||
if (tools_streamed[i] and not tool_choice_function_name) or (
|
||||
self.use_harmony and harmony_tools_streamed[i]
|
||||
):
|
||||
finish_reason_ = "tool_calls"
|
||||
else:
|
||||
@@ -1535,56 +1465,3 @@ class OpenAIServingChat(OpenAIServing):
|
||||
and self.enable_auto_tools
|
||||
and request.tool_choice in ["auto", None]
|
||||
)
|
||||
|
||||
def _should_check_for_unstreamed_tool_arg_tokens(
|
||||
self,
|
||||
delta_message: DeltaMessage | None,
|
||||
output: CompletionOutput,
|
||||
) -> bool:
|
||||
"""
|
||||
Check to see if we should check for unstreamed tool arguments tokens.
|
||||
This is only applicable when auto tool parsing is enabled, the delta
|
||||
is a tool call with arguments.
|
||||
"""
|
||||
|
||||
return bool(
|
||||
# if there is a delta message that includes tool calls which
|
||||
# include a function that has arguments
|
||||
output.finish_reason is not None
|
||||
and self.enable_auto_tools
|
||||
and self.tool_parser
|
||||
and delta_message
|
||||
and delta_message.tool_calls
|
||||
and delta_message.tool_calls[0]
|
||||
and delta_message.tool_calls[0].function
|
||||
and delta_message.tool_calls[0].function.arguments is not None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_remaining_args_delta(
|
||||
delta_message: DeltaMessage,
|
||||
remaining_call: str,
|
||||
index: int,
|
||||
) -> DeltaMessage:
|
||||
"""
|
||||
Create a delta message for remaining tool arguments, preserving
|
||||
id/type/name from the original delta.
|
||||
"""
|
||||
original_tc = next(
|
||||
(tc for tc in delta_message.tool_calls if tc.index == index),
|
||||
None,
|
||||
)
|
||||
original_fn = original_tc.function if original_tc else None
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=index,
|
||||
id=original_tc.id if original_tc else None,
|
||||
type=original_tc.type if original_tc else None,
|
||||
function=DeltaFunctionCall(
|
||||
name=original_fn.name if original_fn else None,
|
||||
arguments=remaining_call,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@@ -18,6 +18,8 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
StreamOptions,
|
||||
StructuralTagResponseFormat,
|
||||
UsageInfo,
|
||||
validate_structural_tag_response_format,
|
||||
validate_structured_outputs_structural_tag,
|
||||
)
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
@@ -370,6 +372,9 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
parameter="response_format",
|
||||
)
|
||||
|
||||
if rf_type == "structural_tag":
|
||||
validate_structural_tag_response_format(response_format)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -397,6 +402,7 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
"outputs ('json', 'regex' or 'choice').",
|
||||
parameter="structured_outputs",
|
||||
)
|
||||
validate_structured_outputs_structural_tag(structured_outputs_kwargs)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
@@ -17,6 +17,7 @@ from pydantic import (
|
||||
)
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
@@ -158,6 +159,80 @@ AnyResponseFormat: TypeAlias = (
|
||||
)
|
||||
|
||||
|
||||
def validate_structural_tag_response_format(
|
||||
response_format: AnyStructuralTagResponseFormat | dict[str, Any],
|
||||
) -> None:
|
||||
"""Validate structural tags before they are sent to the engine.
|
||||
|
||||
Engine-side validation reports malformed structural tags as generation
|
||||
failures. OpenAI request parsing should classify them as bad requests.
|
||||
"""
|
||||
import json
|
||||
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
if isinstance(response_format, dict):
|
||||
try:
|
||||
response_format = TypeAdapter(
|
||||
AnyStructuralTagResponseFormat
|
||||
).validate_python(response_format)
|
||||
except ValidationError as exc:
|
||||
raise VLLMValidationError(
|
||||
"Invalid response_format structural_tag specification.",
|
||||
parameter="response_format",
|
||||
) from exc
|
||||
|
||||
try:
|
||||
payload = json.dumps(response_format.model_dump(by_alias=True))
|
||||
validate_structural_tag_payload(payload, parameter="response_format")
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise VLLMValidationError(
|
||||
"Invalid response_format structural_tag specification.",
|
||||
parameter="response_format",
|
||||
) from exc
|
||||
|
||||
|
||||
def validate_structural_tag_payload(payload: Any, *, parameter: str) -> None:
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar
|
||||
|
||||
if isinstance(payload, str) and not payload:
|
||||
raise VLLMValidationError(
|
||||
f"Invalid {parameter} structural_tag specification.",
|
||||
parameter=parameter,
|
||||
)
|
||||
|
||||
try:
|
||||
validate_xgrammar_grammar(
|
||||
SamplingParams(
|
||||
structured_outputs=StructuredOutputsParams(structural_tag=payload)
|
||||
)
|
||||
)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise VLLMValidationError(
|
||||
f"Invalid {parameter} structural_tag specification.",
|
||||
parameter=parameter,
|
||||
) from exc
|
||||
|
||||
|
||||
def validate_structured_outputs_structural_tag(
|
||||
structured_outputs: Any,
|
||||
) -> None:
|
||||
from vllm.sampling_params import StructuredOutputsParams
|
||||
|
||||
if isinstance(structured_outputs, StructuredOutputsParams):
|
||||
structural_tag = structured_outputs.structural_tag
|
||||
elif isinstance(structured_outputs, dict):
|
||||
structural_tag = structured_outputs.get("structural_tag")
|
||||
else:
|
||||
return
|
||||
if structural_tag is not None:
|
||||
validate_structural_tag_payload(
|
||||
structural_tag,
|
||||
parameter="structured_outputs",
|
||||
)
|
||||
|
||||
|
||||
class StreamOptions(OpenAIBaseModel):
|
||||
include_usage: bool | None = False
|
||||
continuous_usage_stats: bool | None = False
|
||||
|
||||
@@ -575,7 +575,7 @@ def per_token_group_quant_fp8(
|
||||
|
||||
# prefer CUDA/XPU kernel if available
|
||||
# TODO(bnell): this causes some fp8 moe test to fail.
|
||||
if current_platform.is_cuda() and x.is_contiguous():
|
||||
if (current_platform.is_cuda() or current_platform.is_xpu()) and x.is_contiguous():
|
||||
torch.ops._C.per_token_group_fp8_quant(
|
||||
x,
|
||||
x_q,
|
||||
@@ -590,12 +590,6 @@ def per_token_group_quant_fp8(
|
||||
)
|
||||
return x_q, x_s
|
||||
|
||||
if current_platform.is_xpu() and x.is_contiguous():
|
||||
torch.ops._C.per_token_group_fp8_quant(
|
||||
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0
|
||||
)
|
||||
return x_q, x_s
|
||||
|
||||
# TRITON FALLBACK
|
||||
M = x.numel() // group_size
|
||||
N = group_size
|
||||
|
||||
@@ -173,9 +173,64 @@ MiniCPMOAudioInputs: TypeAlias = (
|
||||
|
||||
|
||||
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
audio_features = hf_inputs.get("audio_features")
|
||||
audio_feature_lens = hf_inputs.get("audio_feature_lens")
|
||||
|
||||
# For multi-chunk audio (>30s), audio_features has one item per chunk
|
||||
# (total_chunks) while audio_feature_lens has one item per audio (N).
|
||||
# Use flat to group audio_features by audio so both fields
|
||||
# share the same batch size (N).
|
||||
audio_features_cfg = MultiModalFieldConfig.batched("audio")
|
||||
|
||||
if audio_features is not None and audio_feature_lens is not None:
|
||||
num_features = (
|
||||
len(audio_features)
|
||||
if isinstance(audio_features, (list, tuple))
|
||||
else audio_features.shape[0]
|
||||
)
|
||||
num_audios = (
|
||||
len(audio_feature_lens)
|
||||
if isinstance(audio_feature_lens, (list, tuple))
|
||||
else audio_feature_lens.shape[0]
|
||||
)
|
||||
|
||||
if num_features > num_audios:
|
||||
# Compute the number of chunks belonging to each audio
|
||||
chunks_per_audio: list[int] = []
|
||||
for lens in audio_feature_lens:
|
||||
if isinstance(lens, torch.Tensor):
|
||||
chunks_per_audio.append(lens.numel())
|
||||
else:
|
||||
chunks_per_audio.append(1)
|
||||
|
||||
# When audio_feature_lens is padded (e.g. from batched HF
|
||||
# processor output), numel() over-counts. Fall back to
|
||||
# counting non-zero entries so the sizes sum to num_features.
|
||||
if sum(chunks_per_audio) != num_features:
|
||||
chunks_per_audio = []
|
||||
for lens in audio_feature_lens:
|
||||
if isinstance(lens, torch.Tensor):
|
||||
n = int((lens != 0).sum())
|
||||
chunks_per_audio.append(max(n, 1))
|
||||
else:
|
||||
chunks_per_audio.append(1)
|
||||
|
||||
# Use flat (not flat_from_sizes) because audio_features
|
||||
# is list[Tensor] with variable-length chunks (post-unpad).
|
||||
slice_idxs = [0]
|
||||
for n in chunks_per_audio:
|
||||
slice_idxs.append(slice_idxs[-1] + n)
|
||||
audio_features_cfg = MultiModalFieldConfig.flat(
|
||||
"audio",
|
||||
[
|
||||
slice(slice_idxs[i], slice_idxs[i + 1])
|
||||
for i in range(len(chunks_per_audio))
|
||||
],
|
||||
)
|
||||
|
||||
return dict(
|
||||
**_minicpmv_field_config(hf_inputs),
|
||||
audio_features=MultiModalFieldConfig.batched("audio"),
|
||||
audio_features=audio_features_cfg,
|
||||
audio_feature_lens=MultiModalFieldConfig.batched("audio"),
|
||||
audio_embeds=MultiModalFieldConfig.batched("audio"),
|
||||
)
|
||||
@@ -364,11 +419,23 @@ class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessing
|
||||
|
||||
# Avoid padding since we need the output for each audio to be
|
||||
# independent of other audios for the cache to work correctly
|
||||
# Flatten audio_feature_lens (list of tensors of any
|
||||
# dimensionality, one per audio, each containing per-chunk
|
||||
# lengths) into a flat list of integer lengths so there is
|
||||
# one length per chunk, matching the first dimension of
|
||||
# audio_features. Using flatten() handles 0-D, 1-D, and
|
||||
# higher-dimensional tensors uniformly.
|
||||
flat_feature_lens: list[int] = []
|
||||
for lens in audio_inputs["audio_feature_lens"]:
|
||||
if isinstance(lens, torch.Tensor):
|
||||
flat_feature_lens.extend(lens.flatten().tolist())
|
||||
else:
|
||||
flat_feature_lens.append(int(lens))
|
||||
unpadded_audio_features = [
|
||||
feat[:, :feature_len]
|
||||
for feat, feature_len in zip(
|
||||
feat[:, :length]
|
||||
for feat, length in zip(
|
||||
audio_inputs["audio_features"],
|
||||
audio_inputs["audio_feature_lens"],
|
||||
flat_feature_lens,
|
||||
)
|
||||
]
|
||||
audio_inputs["audio_features"] = unpadded_audio_features
|
||||
|
||||
@@ -320,6 +320,7 @@ class Parser:
|
||||
delta_token_ids: list[int],
|
||||
request: ChatCompletionRequest | ResponsesRequest,
|
||||
prompt_token_ids: list[int] | None = None,
|
||||
finished: bool = False,
|
||||
) -> DeltaMessage | None:
|
||||
"""Parse a single streaming delta, orchestrating reasoning then
|
||||
tool call extraction via internal stream state.
|
||||
@@ -656,12 +657,28 @@ class DelegatingParser(Parser):
|
||||
return False
|
||||
return state.reasoning_ended
|
||||
|
||||
def _append_unstreamed_tool_args(
|
||||
self,
|
||||
delta_message: DeltaMessage | None,
|
||||
) -> None:
|
||||
"""Append parsed-but-unstreamed tool-call arguments to *delta_message*."""
|
||||
if (
|
||||
self._tool_parser is not None
|
||||
and delta_message
|
||||
and delta_message.tool_calls
|
||||
and (last_tc := delta_message.tool_calls[-1]).function
|
||||
):
|
||||
last_tc.function.arguments = (
|
||||
last_tc.function.arguments or ""
|
||||
) + self._tool_parser.get_remaining_unstreamed_args()
|
||||
|
||||
def parse_delta(
|
||||
self,
|
||||
delta_text: str,
|
||||
delta_token_ids: list[int],
|
||||
request: ChatCompletionRequest | ResponsesRequest,
|
||||
prompt_token_ids: list[int] | None = None,
|
||||
finished: bool = False,
|
||||
) -> DeltaMessage | None:
|
||||
state = self._stream_state
|
||||
|
||||
@@ -745,6 +762,10 @@ class DelegatingParser(Parser):
|
||||
|
||||
state.previous_text = current_text
|
||||
state.previous_token_ids = current_token_ids
|
||||
|
||||
if finished:
|
||||
self._append_unstreamed_tool_args(delta_message)
|
||||
|
||||
return delta_message
|
||||
|
||||
|
||||
|
||||
@@ -79,6 +79,25 @@ class ToolParser:
|
||||
else:
|
||||
self.tools = []
|
||||
|
||||
def get_remaining_unstreamed_args(self) -> str:
|
||||
"""Return tool call arguments parsed but not yet streamed."""
|
||||
if not self.prev_tool_call_arr:
|
||||
return ""
|
||||
index = len(self.prev_tool_call_arr) - 1
|
||||
args = self.prev_tool_call_arr[index].get("arguments", {})
|
||||
if isinstance(args, str):
|
||||
expected = args
|
||||
else:
|
||||
expected = json.dumps(args, ensure_ascii=False)
|
||||
actual = (
|
||||
self.streamed_args_for_tool[index]
|
||||
if index < len(self.streamed_args_for_tool)
|
||||
else ""
|
||||
)
|
||||
if expected.startswith(actual):
|
||||
return expected[len(actual) :]
|
||||
return ""
|
||||
|
||||
@cached_property
|
||||
def vocab(self) -> dict[str, int]:
|
||||
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
|
||||
|
||||
Reference in New Issue
Block a user