mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[CI] Add MTP + PD disagg test for Qwen3.5 (#42677)
Signed-off-by: ZhanqiuHu <zhu@redhat.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
@@ -88,7 +88,7 @@ steps:
|
||||
- tests/v1/kv_connector/nixl_integration/
|
||||
commands:
|
||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||
- bash v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
|
||||
- bash v1/kv_connector/nixl_integration/config_sweep_spec_decode_test.sh
|
||||
|
||||
- label: MultiConnector (Nixl+Offloading) PD edge cases (2 GPUs)
|
||||
key: multiconnector-nixl-offloading-pd-edge-cases-2-gpus
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Sweep wrapper for spec decode acceptance tests, following the same pattern
|
||||
# as config_sweep_accuracy_test.sh. Runs spec_decode_acceptance_test.sh once
|
||||
# per configuration.
|
||||
|
||||
SCRIPT="v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh"
|
||||
|
||||
# EAGLE3: Llama-3.1-8B-Instruct with EAGLE3 speculator.
|
||||
eagle3_config="SD_METHOD=eagle3 MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct SD_MODEL=RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3 NUM_SPEC_TOKENS=3"
|
||||
|
||||
# MTP: Qwen3.5-0.8B-Base with hybrid SSM flags.
|
||||
mtp_config="SD_METHOD=mtp MODEL_NAME=Qwen/Qwen3.5-0.8B-Base SD_MODEL=Qwen/Qwen3.5-0.8B-Base NUM_SPEC_TOKENS=1 BLOCK_SIZE=32 MAX_MODEL_LEN=4096 VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 KV_BUFFER_DEVICES=cuda"
|
||||
|
||||
configs=(
|
||||
"$eagle3_config"
|
||||
"$mtp_config"
|
||||
)
|
||||
|
||||
for cfg in "${configs[@]}"; do
|
||||
local_cfg_parts=()
|
||||
read -r -a local_cfg_parts <<< "$cfg"
|
||||
echo "-> Running with: ${cfg}"
|
||||
if ! env "${local_cfg_parts[@]}" bash "${SCRIPT}"; then
|
||||
echo "❌ Test failed for config: ${cfg}"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
echo "✅ All spec decode acceptance tests passed!"
|
||||
@@ -26,7 +26,10 @@
|
||||
# ROCm options: TRITON_ATTN, ROCM_ATTN, ROCM_AITER_FA,
|
||||
# ROCM_AITER_UNIFIED_ATTN
|
||||
# NVIDIA options: FLASH_ATTN, FLASHINFER
|
||||
set -x
|
||||
# VLLM_SSM_CONV_STATE_LAYOUT - SSM conv state layout (e.g. "DS" required for Mamba models)
|
||||
# ENABLE_HMA_FLAG - set to 1 to enable hybrid KV cache manager
|
||||
# VLLM_SERVE_EXTRA_ARGS - comma-separated extra args for vllm serve
|
||||
set -ex
|
||||
|
||||
# ── Model & spec decode config ──────────────────────────────────────────
|
||||
|
||||
@@ -82,6 +85,20 @@ if [[ -z "${ATTENTION_BACKEND:-}" ]]; then
|
||||
fi
|
||||
echo "Using attention backend: ${ATTENTION_BACKEND}"
|
||||
|
||||
# ── HMA & extra serve args ────────────────────────────────────────────
|
||||
|
||||
ENABLE_HMA_VAR=""
|
||||
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
|
||||
ENABLE_HMA_VAR="--no-disable-hybrid-kv-cache-manager"
|
||||
echo "HMA (Hybrid KV Cache Manager) enabled"
|
||||
fi
|
||||
|
||||
EXTRA_SERVE_ARGS=()
|
||||
if [[ -n "${VLLM_SERVE_EXTRA_ARGS:-}" ]]; then
|
||||
IFS=',' read -r -a EXTRA_SERVE_ARGS <<< "$VLLM_SERVE_EXTRA_ARGS"
|
||||
echo "Extra serve args: ${EXTRA_SERVE_ARGS[*]}"
|
||||
fi
|
||||
|
||||
cleanup_instances() {
|
||||
echo ""
|
||||
echo "Cleaning up..."
|
||||
@@ -228,6 +245,7 @@ run_test_for_device() {
|
||||
${GPU_DEVICE_VAR}=$GPU_ID \
|
||||
VLLM_KV_CACHE_LAYOUT='HND' \
|
||||
UCX_NET_DEVICES=all \
|
||||
${VLLM_SSM_CONV_STATE_LAYOUT:+VLLM_SSM_CONV_STATE_LAYOUT=$VLLM_SSM_CONV_STATE_LAYOUT} \
|
||||
VLLM_NIXL_SIDE_CHANNEL_HOST=$NIXL_SIDE_CHANNEL_HOST \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
|
||||
vllm serve $MODEL_NAME \
|
||||
@@ -239,7 +257,9 @@ run_test_for_device() {
|
||||
--tensor-parallel-size $PREFILLER_TP_SIZE \
|
||||
--kv-transfer-config "$kv_config" \
|
||||
--speculative-config "$PREFILL_SPEC_CONFIG" \
|
||||
--attention-backend $ATTENTION_BACKEND &
|
||||
--attention-backend $ATTENTION_BACKEND \
|
||||
${ENABLE_HMA_VAR} \
|
||||
${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} &
|
||||
local SERVER_PID=$!
|
||||
|
||||
PREFILL_HOSTS+=("$SERVER_HOST")
|
||||
@@ -265,6 +285,7 @@ run_test_for_device() {
|
||||
${GPU_DEVICE_VAR}=$GPU_ID \
|
||||
VLLM_KV_CACHE_LAYOUT='HND' \
|
||||
UCX_NET_DEVICES=all \
|
||||
${VLLM_SSM_CONV_STATE_LAYOUT:+VLLM_SSM_CONV_STATE_LAYOUT=$VLLM_SSM_CONV_STATE_LAYOUT} \
|
||||
VLLM_NIXL_SIDE_CHANNEL_HOST=$NIXL_SIDE_CHANNEL_HOST \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
|
||||
vllm serve $MODEL_NAME \
|
||||
@@ -276,7 +297,9 @@ run_test_for_device() {
|
||||
--tensor-parallel-size $DECODER_TP_SIZE \
|
||||
--kv-transfer-config "$kv_config" \
|
||||
--speculative-config "$DECODE_SPEC_CONFIG" \
|
||||
--attention-backend $ATTENTION_BACKEND &
|
||||
--attention-backend $ATTENTION_BACKEND \
|
||||
${ENABLE_HMA_VAR} \
|
||||
${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} &
|
||||
local SERVER_PID=$!
|
||||
|
||||
DECODE_HOSTS+=("$SERVER_HOST")
|
||||
@@ -303,6 +326,7 @@ run_test_for_device() {
|
||||
DECODE_PORT=${DECODE_PORTS[0]} \
|
||||
SERVER_HOST=$SERVER_HOST \
|
||||
TEST_MODEL=$MODEL_NAME \
|
||||
SD_METHOD=$SD_METHOD \
|
||||
python3 -m pytest -s -x "${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_spec_decode_acceptance.py"
|
||||
|
||||
# Tear down before next iteration
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""NixlConnector PD + EAGLE3 speculative decoding acceptance length test.
|
||||
"""NixlConnector PD + speculative decoding acceptance length test.
|
||||
|
||||
- Loads MT-Bench prompts (80 prompts, 256 output tokens)
|
||||
- Sends through the PD proxy (completions API)
|
||||
- Scrapes Prometheus metrics from the decode server
|
||||
- Asserts acceptance length matches standalone EAGLE3 baselines
|
||||
- Asserts acceptance metrics match standalone baselines
|
||||
|
||||
Baselines from tests/v1/spec_decode/test_acceptance_length.py
|
||||
(standalone EAGLE3 with same model/drafter on MT-Bench, temp=0).
|
||||
PD disaggregation via NixlConnector should match within tolerance.
|
||||
Supports EAGLE3 (default) and MTP, selected via SD_METHOD env var.
|
||||
|
||||
Environment variables (set by spec_decode_acceptance_test.sh):
|
||||
TEST_MODEL - target model name
|
||||
DECODE_PORT - port of the decode vLLM server (for /metrics)
|
||||
SD_METHOD - "eagle3" (default) or "mtp"
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -31,28 +30,38 @@ SERVER_HOST = os.environ.get("SERVER_HOST", "127.0.0.1")
|
||||
PROXY_BASE_URL = f"http://{SERVER_HOST}:8192/v1"
|
||||
DECODE_PORT = os.environ.get("DECODE_PORT", "8200")
|
||||
MODEL_NAME = os.environ.get("TEST_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
|
||||
SD_METHOD = os.environ.get("SD_METHOD", "eagle3").lower()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Eagle3ModelConfig:
|
||||
verifier: str
|
||||
drafter: str
|
||||
class ModelConfig:
|
||||
model: str
|
||||
method: str
|
||||
expected_acceptance_length: float
|
||||
drafter: str = ""
|
||||
expected_acceptance_lengths_per_pos: list[float] = field(default_factory=list)
|
||||
expected_acceptance_rate: float | None = None
|
||||
id: str = ""
|
||||
rtol: float | None = None
|
||||
|
||||
|
||||
# Standalone EAGLE3 baselines (MT-Bench, 80 prompts, 256 tokens, temp=0).
|
||||
# Source: tests/v1/spec_decode/test_acceptance_length.py
|
||||
EAGLE3_MODEL_CONFIGS = [
|
||||
Eagle3ModelConfig(
|
||||
verifier="meta-llama/Llama-3.1-8B-Instruct",
|
||||
# Standalone baselines (MT-Bench, 80 prompts, 256 tokens, temp=0).
|
||||
# EAGLE3 source: tests/v1/spec_decode/test_acceptance_length.py
|
||||
MODEL_CONFIGS = [
|
||||
ModelConfig(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
method="eagle3",
|
||||
drafter="RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
|
||||
expected_acceptance_length=2.60,
|
||||
expected_acceptance_lengths_per_pos=[0.7296, 0.5208, 0.3545],
|
||||
id="llama3-8b-eagle3",
|
||||
),
|
||||
ModelConfig(
|
||||
model="Qwen/Qwen3.5-0.8B-Base",
|
||||
method="mtp",
|
||||
expected_acceptance_length=1.798,
|
||||
id="qwen35-0.8b-mtp",
|
||||
),
|
||||
]
|
||||
|
||||
DEFAULT_NUM_PROMPTS = 80
|
||||
@@ -60,14 +69,14 @@ DEFAULT_OUTPUT_LEN = 256
|
||||
DEFAULT_RTOL = 0.05
|
||||
|
||||
|
||||
def _get_model_config() -> Eagle3ModelConfig:
|
||||
"""Get the model config matching MODEL_NAME."""
|
||||
for config in EAGLE3_MODEL_CONFIGS:
|
||||
if config.verifier == MODEL_NAME:
|
||||
def _get_model_config() -> ModelConfig:
|
||||
"""Get the model config matching MODEL_NAME and SD_METHOD."""
|
||||
for config in MODEL_CONFIGS:
|
||||
if config.model == MODEL_NAME and config.method == SD_METHOD:
|
||||
return config
|
||||
raise ValueError(
|
||||
f"No Eagle3ModelConfig found for model {MODEL_NAME}. "
|
||||
f"Available: {[c.verifier for c in EAGLE3_MODEL_CONFIGS]}"
|
||||
f"No config for model={MODEL_NAME}, method={SD_METHOD}. "
|
||||
f"Available: {[(c.model, c.method) for c in MODEL_CONFIGS]}"
|
||||
)
|
||||
|
||||
|
||||
@@ -161,47 +170,60 @@ def test_spec_decode_acceptance_length():
|
||||
assert n_drafts > 0, "No spec-decode drafts were generated"
|
||||
|
||||
acceptance_length = 1 + (n_accepted / n_drafts)
|
||||
|
||||
per_pos_counts = _fetch_per_position_acceptance()
|
||||
per_pos_rates = [
|
||||
per_pos_counts.get(i, 0) / n_drafts
|
||||
for i in range(len(config.expected_acceptance_lengths_per_pos))
|
||||
]
|
||||
|
||||
# ── Report ────────────────────────────────────────────────────────
|
||||
expected = config.expected_acceptance_length
|
||||
expected_per_pos = config.expected_acceptance_lengths_per_pos
|
||||
|
||||
print(
|
||||
f"\n{config.id}: acceptance_length={acceptance_length:.3f} "
|
||||
f"(expected={expected:.3f})"
|
||||
)
|
||||
print(f" Drafts: {n_drafts:.0f}, Accepted: {n_accepted:.0f}")
|
||||
for i, (actual, exp) in enumerate(zip(per_pos_rates, expected_per_pos)):
|
||||
print(f" Position {i}: {actual:.4f} (expected: {exp:.4f})")
|
||||
|
||||
# ── Assert overall acceptance length ──────────────────────────────
|
||||
# ── Assert acceptance length (all methods) ────────────────────────
|
||||
rel_error = abs(acceptance_length - expected) / expected
|
||||
|
||||
assert rel_error <= rtol, (
|
||||
f"Acceptance length regression for {config.id}! "
|
||||
f"Expected: {expected:.3f}, "
|
||||
f"Got: {acceptance_length:.3f}, "
|
||||
f"Relative error: {rel_error:.2%} (tolerance: {rtol:.0%}). "
|
||||
f"This may indicate drafter KV was not correctly transferred."
|
||||
f"Relative error: {rel_error:.2%} (tolerance: {rtol:.0%})"
|
||||
)
|
||||
|
||||
# ── Assert per-position acceptance ────────────────────────────────
|
||||
for i, (actual, exp) in enumerate(zip(per_pos_rates, expected_per_pos)):
|
||||
if exp > 0:
|
||||
pos_err = abs(actual - exp) / exp
|
||||
assert pos_err <= rtol, (
|
||||
f"Per-position acceptance regression at position {i} "
|
||||
f"for {config.id}! "
|
||||
f"Expected: {exp:.4f}, Got: {actual:.4f}, "
|
||||
f"Relative error: {pos_err:.2%} "
|
||||
f"(tolerance: {rtol:.0%})"
|
||||
)
|
||||
# ── Assert per-position acceptance (EAGLE3) ───────────────────────
|
||||
if config.expected_acceptance_lengths_per_pos:
|
||||
per_pos_counts = _fetch_per_position_acceptance()
|
||||
per_pos_rates = [
|
||||
per_pos_counts.get(i, 0) / n_drafts
|
||||
for i in range(len(config.expected_acceptance_lengths_per_pos))
|
||||
]
|
||||
for i, (actual, exp) in enumerate(
|
||||
zip(per_pos_rates, config.expected_acceptance_lengths_per_pos)
|
||||
):
|
||||
print(f" Position {i}: {actual:.4f} (expected: {exp:.4f})")
|
||||
if exp > 0:
|
||||
pos_err = abs(actual - exp) / exp
|
||||
assert pos_err <= rtol, (
|
||||
f"Per-position regression at pos {i} for {config.id}! "
|
||||
f"Expected: {exp:.4f}, Got: {actual:.4f}, "
|
||||
f"Relative error: {pos_err:.2%} (tolerance: {rtol:.0%})"
|
||||
)
|
||||
|
||||
# ── Assert acceptance rate (MTP) ──────────────────────────────────
|
||||
if config.expected_acceptance_rate is not None:
|
||||
n_draft_tokens = _fetch_metric("vllm:spec_decode_num_draft_tokens_total")
|
||||
acceptance_rate = n_accepted / n_draft_tokens if n_draft_tokens > 0 else 0.0
|
||||
print(
|
||||
f" Acceptance rate: {acceptance_rate:.3f} "
|
||||
f"(expected: {config.expected_acceptance_rate:.3f})"
|
||||
)
|
||||
rate_err = (
|
||||
abs(acceptance_rate - config.expected_acceptance_rate)
|
||||
/ config.expected_acceptance_rate
|
||||
)
|
||||
assert rate_err <= rtol, (
|
||||
f"Acceptance rate regression for {config.id}! "
|
||||
f"Expected: {config.expected_acceptance_rate:.3f}, "
|
||||
f"Got: {acceptance_rate:.3f}, "
|
||||
f"Relative error: {rate_err:.2%} (tolerance: {rtol:.0%})"
|
||||
)
|
||||
|
||||
print(
|
||||
f"\n=== PASS: {config.id} acceptance length {acceptance_length:.3f} "
|
||||
|
||||
Reference in New Issue
Block a user