[ROCm][CI] Fix and stabilize EAGLE3 acceptance tests (#41294)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
Co-authored-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
Andreas Karatzas
2026-06-01 12:40:01 -05:00
committed by GitHub
parent 035733515f
commit fd9e91d7e4
+30 -14
View File
@@ -39,6 +39,8 @@ class Eagle3ModelConfig:
marks: list = field(default_factory=list)
# Custom relative tolerance (defaults to DEFAULT_RTOL if None)
rtol: float | None = None
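# ROCm-specific expected per-position acceptance lengths; when non-empty,
# these replace expected_acceptance_lengths_per_pos on ROCm.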
rocm_expected_acceptance_lengths_per_pos: list[float] = field(default_factory=list)
# Model configurations for EAGLE3 acceptance length tests.
@@ -69,6 +71,7 @@ EAGLE3_MODEL_CONFIGS = [
# FLASHINFER incompatible: gpt-oss-20b uses sink attention which
# FLASHINFER does not support ("sink setting not supported")
excluded_backends={AttentionBackendEnum.FLASHINFER},
rocm_expected_acceptance_lengths_per_pos=[0.7040, 0.4820, 0.3350],
),
Eagle3ModelConfig(
verifier="Qwen/Qwen3-VL-30B-A3B-Instruct-FP8",
@@ -99,16 +102,14 @@ EXCLUDED_BACKENDS = {AttentionBackendEnum.FLEX_ATTENTION}
def get_available_attention_backends() -> list[str]:
if current_platform.is_rocm():
return ["auto"]
# Check if get_valid_backends is actually defined in the platform class
# (not just returning None from __getattr__)
get_valid_backends = getattr(current_platform.__class__, "get_valid_backends", None)
if get_valid_backends is None:
if current_platform.is_rocm():
# ROCm uses Triton as its default attention backend since
# Flash Attention is not supported.
return ["TRITON_ATTN"]
else:
return ["FLASH_ATTN"]
return ["FLASH_ATTN"]
device_capability = current_platform.get_device_capability()
if device_capability is None:
@@ -167,6 +168,8 @@ def get_mt_bench_prompts(
disable_shuffle=False,
skip_chat_template=False,
trust_remote_code=False,
enable_multimodal_chat=False,
request_id_prefix="",
)
samples = get_samples(args, tokenizer)
prompt_ids = [
@@ -233,9 +236,12 @@ def test_eagle3_acceptance_length(
monkeypatch: pytest.MonkeyPatch,
):
# Skip if this backend is incompatible with the model
backend_enum = AttentionBackendEnum[attention_backend]
if backend_enum in model_config.excluded_backends:
pytest.skip(f"{attention_backend} is incompatible with {model_config.id}")
attention_config = None
if attention_backend != "auto":
backend_enum = AttentionBackendEnum[attention_backend]
if backend_enum in model_config.excluded_backends:
pytest.skip(f"{attention_backend} is incompatible with {model_config.id}")
attention_config = {"backend": attention_backend}
with monkeypatch.context() as m:
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@@ -247,11 +253,16 @@ def test_eagle3_acceptance_length(
"model": model_config.drafter,
"num_speculative_tokens": num_spec_tokens,
},
attention_config={"backend": attention_backend},
attention_config=attention_config,
tensor_parallel_size=tp_size,
gpu_memory_utilization=0.7,
disable_log_stats=False,
max_model_len=DEFAULT_MAX_MODEL_LEN,
# Qwen/Qwen3-30B-A3B-FP8 with TP=4 needs expert parallelism (EP); the
# check below enables it for Qwen3-VL verifiers.
# https://github.com/vllm-project/vllm/issues/25292
enable_expert_parallel=(
tp_size == 4 and "Qwen3-VL" in model_config.verifier
),
) as vllm_runner:
tokenizer = vllm_runner.llm.get_tokenizer()
prompt_ids = get_mt_bench_prompts(tokenizer, DEFAULT_NUM_PROMPTS)
@@ -272,6 +283,11 @@ def test_eagle3_acceptance_length(
expected = model_config.expected_acceptance_length
actual_per_pos = results["acceptance_lengths_per_pos"]
expected_per_pos = model_config.expected_acceptance_lengths_per_pos
if (
current_platform.is_rocm()
and model_config.rocm_expected_acceptance_lengths_per_pos
):
expected_per_pos = model_config.rocm_expected_acceptance_lengths_per_pos
rel_error = abs(actual_acceptance_length - expected) / expected
@@ -294,14 +310,14 @@ def test_eagle3_acceptance_length(
zip(actual_per_pos, expected_per_pos)
):
if exp > 0:
pos_rel_error = abs(actual - exp) / exp
assert pos_rel_error <= rtol, (
min_expected = exp * (1 - rtol)
assert actual >= min_expected, (
f"Per-position acceptance length regression at pos {pos} "
f"for {model_config.id}!\n"
f" Expected: {exp:.3f}\n"
f" Actual: {actual:.3f}\n"
f" Relative error: {pos_rel_error:.2%} "
f"(tolerance: {rtol:.2%})"
f" Minimum: {min_expected:.3f}\n"
f" Tolerance: rtol={rtol:.2%}"
)
print(