From fd9e91d7e4116c9f3d1a3fc237677c925bf9d6d9 Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Mon, 1 Jun 2026 12:40:01 -0500 Subject: [PATCH] [ROCm][CI] Fix and stabilize EAGLE3 acceptance tests (#41294) Signed-off-by: Andreas Karatzas Signed-off-by: Micah Williamson Co-authored-by: Micah Williamson --- .../v1/spec_decode/test_acceptance_length.py | 44 +++++++++++++------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/tests/v1/spec_decode/test_acceptance_length.py b/tests/v1/spec_decode/test_acceptance_length.py index 62ff100fdbf..90e3821e2f1 100644 --- a/tests/v1/spec_decode/test_acceptance_length.py +++ b/tests/v1/spec_decode/test_acceptance_length.py @@ -39,6 +39,8 @@ class Eagle3ModelConfig: marks: list = field(default_factory=list) # Custom relative tolerance (defaults to DEFAULT_RTOL if None) rtol: float | None = None + # ROCm-specific test configuration + rocm_expected_acceptance_lengths_per_pos: list[float] = field(default_factory=list) # Model configurations for EAGLE3 acceptance length tests. @@ -69,6 +71,7 @@ EAGLE3_MODEL_CONFIGS = [ # FLASHINFER incompatible: gpt-oss-20b uses sink attention which # FLASHINFER does not support ("sink setting not supported") excluded_backends={AttentionBackendEnum.FLASHINFER}, + rocm_expected_acceptance_lengths_per_pos=[0.7040, 0.4820, 0.3350], ), Eagle3ModelConfig( verifier="Qwen/Qwen3-VL-30B-A3B-Instruct-FP8", @@ -99,16 +102,14 @@ EXCLUDED_BACKENDS = {AttentionBackendEnum.FLEX_ATTENTION} def get_available_attention_backends() -> list[str]: + if current_platform.is_rocm(): + return ["auto"] + # Check if get_valid_backends is actually defined in the platform class # (not just returning None from __getattr__) get_valid_backends = getattr(current_platform.__class__, "get_valid_backends", None) if get_valid_backends is None: - if current_platform.is_rocm(): - # ROCm uses Triton as its default attention backend since - # Flash Attention is not supported. - return ["TRITON_ATTN"] - else: - return ["FLASH_ATTN"] + return ["FLASH_ATTN"] device_capability = current_platform.get_device_capability() if device_capability is None: @@ -167,6 +168,8 @@ def get_mt_bench_prompts( disable_shuffle=False, skip_chat_template=False, trust_remote_code=False, + enable_multimodal_chat=False, + request_id_prefix="", ) samples = get_samples(args, tokenizer) prompt_ids = [ @@ -233,9 +236,12 @@ def test_eagle3_acceptance_length( monkeypatch: pytest.MonkeyPatch, ): # Skip if this backend is incompatible with the model - backend_enum = AttentionBackendEnum[attention_backend] - if backend_enum in model_config.excluded_backends: - pytest.skip(f"{attention_backend} is incompatible with {model_config.id}") + attention_config = None + if attention_backend != "auto": + backend_enum = AttentionBackendEnum[attention_backend] + if backend_enum in model_config.excluded_backends: + pytest.skip(f"{attention_backend} is incompatible with {model_config.id}") + attention_config = {"backend": attention_backend} with monkeypatch.context() as m: m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") @@ -247,11 +253,16 @@ def test_eagle3_acceptance_length( "model": model_config.drafter, "num_speculative_tokens": num_spec_tokens, }, - attention_config={"backend": attention_backend}, + attention_config=attention_config, tensor_parallel_size=tp_size, gpu_memory_utilization=0.7, disable_log_stats=False, max_model_len=DEFAULT_MAX_MODEL_LEN, + # Qwen/Qwen3-30B-A3B-FP8 with TP=4 needs EP + # https://github.com/vllm-project/vllm/issues/25292 + enable_expert_parallel=( + tp_size == 4 and "Qwen3-VL" in model_config.verifier + ), ) as vllm_runner: tokenizer = vllm_runner.llm.get_tokenizer() prompt_ids = get_mt_bench_prompts(tokenizer, DEFAULT_NUM_PROMPTS) @@ -272,6 +283,11 @@ def test_eagle3_acceptance_length( expected = model_config.expected_acceptance_length actual_per_pos = results["acceptance_lengths_per_pos"] expected_per_pos = model_config.expected_acceptance_lengths_per_pos + if ( + current_platform.is_rocm() + and model_config.rocm_expected_acceptance_lengths_per_pos + ): + expected_per_pos = model_config.rocm_expected_acceptance_lengths_per_pos rel_error = abs(actual_acceptance_length - expected) / expected @@ -294,14 +310,14 @@ def test_eagle3_acceptance_length( zip(actual_per_pos, expected_per_pos) ): if exp > 0: - pos_rel_error = abs(actual - exp) / exp - assert pos_rel_error <= rtol, ( + min_expected = exp * (1 - rtol) + assert actual >= min_expected, ( f"Per-position acceptance length regression at pos {pos} " f"for {model_config.id}!\n" f" Expected: {exp:.3f}\n" f" Actual: {actual:.3f}\n" - f" Relative error: {pos_rel_error:.2%} " - f"(tolerance: {rtol:.2%})" + f" Minimum: {min_expected:.3f}\n" + f" Tolerance: rtol={rtol:.2%}" ) print(