[CI] Add MTP + PD disagg test for Qwen3.5 (#42677)

Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
zhanqiuhu
2026-05-19 05:44:33 -04:00
committed by GitHub
parent ef54a4d604
commit 129019f334
4 changed files with 126 additions and 49 deletions
+1 -1
View File
@@ -88,7 +88,7 @@ steps:
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
- bash v1/kv_connector/nixl_integration/config_sweep_spec_decode_test.sh
- label: MultiConnector (Nixl+Offloading) PD edge cases (2 GPUs)
key: multiconnector-nixl-offloading-pd-edge-cases-2-gpus
@@ -0,0 +1,31 @@
#!/usr/bin/env bash
set -euo pipefail
# Sweep wrapper for spec decode acceptance tests, following the same pattern
# as config_sweep_accuracy_test.sh. Runs spec_decode_acceptance_test.sh once
# per configuration.
SCRIPT="v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh"
# EAGLE3: Llama-3.1-8B-Instruct with EAGLE3 speculator.
eagle3_config="SD_METHOD=eagle3 MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct SD_MODEL=RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3 NUM_SPEC_TOKENS=3"
# MTP: Qwen3.5-0.8B-Base with hybrid SSM flags.
mtp_config="SD_METHOD=mtp MODEL_NAME=Qwen/Qwen3.5-0.8B-Base SD_MODEL=Qwen/Qwen3.5-0.8B-Base NUM_SPEC_TOKENS=1 BLOCK_SIZE=32 MAX_MODEL_LEN=4096 VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 KV_BUFFER_DEVICES=cuda"
configs=(
"$eagle3_config"
"$mtp_config"
)
for cfg in "${configs[@]}"; do
local_cfg_parts=()
read -r -a local_cfg_parts <<< "$cfg"
echo "-> Running with: ${cfg}"
if ! env "${local_cfg_parts[@]}" bash "${SCRIPT}"; then
echo "❌ Test failed for config: ${cfg}"
exit 1
fi
done
echo "✅ All spec decode acceptance tests passed!"
@@ -26,7 +26,10 @@
# ROCm options: TRITON_ATTN, ROCM_ATTN, ROCM_AITER_FA,
# ROCM_AITER_UNIFIED_ATTN
# NVIDIA options: FLASH_ATTN, FLASHINFER
set -x
# VLLM_SSM_CONV_STATE_LAYOUT - SSM conv state layout (e.g. "DS" required for Mamba models)
# ENABLE_HMA_FLAG - set to 1 to enable hybrid KV cache manager
# VLLM_SERVE_EXTRA_ARGS - comma-separated extra args for vllm serve
set -ex
# ── Model & spec decode config ──────────────────────────────────────────
@@ -82,6 +85,20 @@ if [[ -z "${ATTENTION_BACKEND:-}" ]]; then
fi
echo "Using attention backend: ${ATTENTION_BACKEND}"
# ── HMA & extra serve args ────────────────────────────────────────────
ENABLE_HMA_VAR=""
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
ENABLE_HMA_VAR="--no-disable-hybrid-kv-cache-manager"
echo "HMA (Hybrid KV Cache Manager) enabled"
fi
EXTRA_SERVE_ARGS=()
if [[ -n "${VLLM_SERVE_EXTRA_ARGS:-}" ]]; then
IFS=',' read -r -a EXTRA_SERVE_ARGS <<< "$VLLM_SERVE_EXTRA_ARGS"
echo "Extra serve args: ${EXTRA_SERVE_ARGS[*]}"
fi
cleanup_instances() {
echo ""
echo "Cleaning up..."
@@ -228,6 +245,7 @@ run_test_for_device() {
${GPU_DEVICE_VAR}=$GPU_ID \
VLLM_KV_CACHE_LAYOUT='HND' \
UCX_NET_DEVICES=all \
${VLLM_SSM_CONV_STATE_LAYOUT:+VLLM_SSM_CONV_STATE_LAYOUT=$VLLM_SSM_CONV_STATE_LAYOUT} \
VLLM_NIXL_SIDE_CHANNEL_HOST=$NIXL_SIDE_CHANNEL_HOST \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $MODEL_NAME \
@@ -239,7 +257,9 @@ run_test_for_device() {
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config "$kv_config" \
--speculative-config "$PREFILL_SPEC_CONFIG" \
--attention-backend $ATTENTION_BACKEND &
--attention-backend $ATTENTION_BACKEND \
${ENABLE_HMA_VAR} \
${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} &
local SERVER_PID=$!
PREFILL_HOSTS+=("$SERVER_HOST")
@@ -265,6 +285,7 @@ run_test_for_device() {
${GPU_DEVICE_VAR}=$GPU_ID \
VLLM_KV_CACHE_LAYOUT='HND' \
UCX_NET_DEVICES=all \
${VLLM_SSM_CONV_STATE_LAYOUT:+VLLM_SSM_CONV_STATE_LAYOUT=$VLLM_SSM_CONV_STATE_LAYOUT} \
VLLM_NIXL_SIDE_CHANNEL_HOST=$NIXL_SIDE_CHANNEL_HOST \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $MODEL_NAME \
@@ -276,7 +297,9 @@ run_test_for_device() {
--tensor-parallel-size $DECODER_TP_SIZE \
--kv-transfer-config "$kv_config" \
--speculative-config "$DECODE_SPEC_CONFIG" \
--attention-backend $ATTENTION_BACKEND &
--attention-backend $ATTENTION_BACKEND \
${ENABLE_HMA_VAR} \
${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} &
local SERVER_PID=$!
DECODE_HOSTS+=("$SERVER_HOST")
@@ -303,6 +326,7 @@ run_test_for_device() {
DECODE_PORT=${DECODE_PORTS[0]} \
SERVER_HOST=$SERVER_HOST \
TEST_MODEL=$MODEL_NAME \
SD_METHOD=$SD_METHOD \
python3 -m pytest -s -x "${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_spec_decode_acceptance.py"
# Tear down before next iteration
@@ -1,19 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""NixlConnector PD + EAGLE3 speculative decoding acceptance length test.
"""NixlConnector PD + speculative decoding acceptance length test.
- Loads MT-Bench prompts (80 prompts, 256 output tokens)
- Sends through the PD proxy (completions API)
- Scrapes Prometheus metrics from the decode server
- Asserts acceptance length matches standalone EAGLE3 baselines
- Asserts acceptance metrics match standalone baselines
Baselines from tests/v1/spec_decode/test_acceptance_length.py
(standalone EAGLE3 with same model/drafter on MT-Bench, temp=0).
PD disaggregation via NixlConnector should match within tolerance.
Supports EAGLE3 (default) and MTP, selected via SD_METHOD env var.
Environment variables (set by spec_decode_acceptance_test.sh):
TEST_MODEL - target model name
DECODE_PORT - port of the decode vLLM server (for /metrics)
SD_METHOD - "eagle3" (default) or "mtp"
"""
import os
@@ -31,28 +30,38 @@ SERVER_HOST = os.environ.get("SERVER_HOST", "127.0.0.1")
PROXY_BASE_URL = f"http://{SERVER_HOST}:8192/v1"
DECODE_PORT = os.environ.get("DECODE_PORT", "8200")
MODEL_NAME = os.environ.get("TEST_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
SD_METHOD = os.environ.get("SD_METHOD", "eagle3").lower()
@dataclass
class Eagle3ModelConfig:
verifier: str
drafter: str
class ModelConfig:
model: str
method: str
expected_acceptance_length: float
drafter: str = ""
expected_acceptance_lengths_per_pos: list[float] = field(default_factory=list)
expected_acceptance_rate: float | None = None
id: str = ""
rtol: float | None = None
# Standalone EAGLE3 baselines (MT-Bench, 80 prompts, 256 tokens, temp=0).
# Source: tests/v1/spec_decode/test_acceptance_length.py
EAGLE3_MODEL_CONFIGS = [
Eagle3ModelConfig(
verifier="meta-llama/Llama-3.1-8B-Instruct",
# Standalone baselines (MT-Bench, 80 prompts, 256 tokens, temp=0).
# EAGLE3 source: tests/v1/spec_decode/test_acceptance_length.py
MODEL_CONFIGS = [
ModelConfig(
model="meta-llama/Llama-3.1-8B-Instruct",
method="eagle3",
drafter="RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
expected_acceptance_length=2.60,
expected_acceptance_lengths_per_pos=[0.7296, 0.5208, 0.3545],
id="llama3-8b-eagle3",
),
ModelConfig(
model="Qwen/Qwen3.5-0.8B-Base",
method="mtp",
expected_acceptance_length=1.798,
id="qwen35-0.8b-mtp",
),
]
DEFAULT_NUM_PROMPTS = 80
@@ -60,14 +69,14 @@ DEFAULT_OUTPUT_LEN = 256
DEFAULT_RTOL = 0.05
def _get_model_config() -> Eagle3ModelConfig:
"""Get the model config matching MODEL_NAME."""
for config in EAGLE3_MODEL_CONFIGS:
if config.verifier == MODEL_NAME:
def _get_model_config() -> ModelConfig:
"""Get the model config matching MODEL_NAME and SD_METHOD."""
for config in MODEL_CONFIGS:
if config.model == MODEL_NAME and config.method == SD_METHOD:
return config
raise ValueError(
f"No Eagle3ModelConfig found for model {MODEL_NAME}. "
f"Available: {[c.verifier for c in EAGLE3_MODEL_CONFIGS]}"
f"No config for model={MODEL_NAME}, method={SD_METHOD}. "
f"Available: {[(c.model, c.method) for c in MODEL_CONFIGS]}"
)
@@ -161,47 +170,60 @@ def test_spec_decode_acceptance_length():
assert n_drafts > 0, "No spec-decode drafts were generated"
acceptance_length = 1 + (n_accepted / n_drafts)
per_pos_counts = _fetch_per_position_acceptance()
per_pos_rates = [
per_pos_counts.get(i, 0) / n_drafts
for i in range(len(config.expected_acceptance_lengths_per_pos))
]
# ── Report ────────────────────────────────────────────────────────
expected = config.expected_acceptance_length
expected_per_pos = config.expected_acceptance_lengths_per_pos
print(
f"\n{config.id}: acceptance_length={acceptance_length:.3f} "
f"(expected={expected:.3f})"
)
print(f" Drafts: {n_drafts:.0f}, Accepted: {n_accepted:.0f}")
for i, (actual, exp) in enumerate(zip(per_pos_rates, expected_per_pos)):
print(f" Position {i}: {actual:.4f} (expected: {exp:.4f})")
# ── Assert overall acceptance length ──────────────────────────────
# ── Assert acceptance length (all methods) ────────────────────────
rel_error = abs(acceptance_length - expected) / expected
assert rel_error <= rtol, (
f"Acceptance length regression for {config.id}! "
f"Expected: {expected:.3f}, "
f"Got: {acceptance_length:.3f}, "
f"Relative error: {rel_error:.2%} (tolerance: {rtol:.0%}). "
f"This may indicate drafter KV was not correctly transferred."
f"Relative error: {rel_error:.2%} (tolerance: {rtol:.0%})"
)
# ── Assert per-position acceptance ────────────────────────────────
for i, (actual, exp) in enumerate(zip(per_pos_rates, expected_per_pos)):
if exp > 0:
pos_err = abs(actual - exp) / exp
assert pos_err <= rtol, (
f"Per-position acceptance regression at position {i} "
f"for {config.id}! "
f"Expected: {exp:.4f}, Got: {actual:.4f}, "
f"Relative error: {pos_err:.2%} "
f"(tolerance: {rtol:.0%})"
)
# ── Assert per-position acceptance (EAGLE3) ───────────────────────
if config.expected_acceptance_lengths_per_pos:
per_pos_counts = _fetch_per_position_acceptance()
per_pos_rates = [
per_pos_counts.get(i, 0) / n_drafts
for i in range(len(config.expected_acceptance_lengths_per_pos))
]
for i, (actual, exp) in enumerate(
zip(per_pos_rates, config.expected_acceptance_lengths_per_pos)
):
print(f" Position {i}: {actual:.4f} (expected: {exp:.4f})")
if exp > 0:
pos_err = abs(actual - exp) / exp
assert pos_err <= rtol, (
f"Per-position regression at pos {i} for {config.id}! "
f"Expected: {exp:.4f}, Got: {actual:.4f}, "
f"Relative error: {pos_err:.2%} (tolerance: {rtol:.0%})"
)
# ── Assert acceptance rate (MTP) ──────────────────────────────────
if config.expected_acceptance_rate is not None:
n_draft_tokens = _fetch_metric("vllm:spec_decode_num_draft_tokens_total")
acceptance_rate = n_accepted / n_draft_tokens if n_draft_tokens > 0 else 0.0
print(
f" Acceptance rate: {acceptance_rate:.3f} "
f"(expected: {config.expected_acceptance_rate:.3f})"
)
rate_err = (
abs(acceptance_rate - config.expected_acceptance_rate)
/ config.expected_acceptance_rate
)
assert rate_err <= rtol, (
f"Acceptance rate regression for {config.id}! "
f"Expected: {config.expected_acceptance_rate:.3f}, "
f"Got: {acceptance_rate:.3f}, "
f"Relative error: {rate_err:.2%} (tolerance: {rtol:.0%})"
)
print(
f"\n=== PASS: {config.id} acceptance length {acceptance_length:.3f} "