[Attention] Remove deprecated MLA prefill arguments (#42555)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-05-14 13:34:06 -04:00
committed by GitHub
parent ae4f59f0ec
commit 9898f94abe
5 changed files with 27 additions and 126 deletions
+14 -35
View File
@@ -29,6 +29,7 @@ from vllm.config import (
VllmConfig,
set_current_vllm_config,
)
from vllm.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum
# ============================================================================
# VllmConfig Creation
@@ -79,8 +80,8 @@ def create_minimal_vllm_config(
index_topk: Optional topk value for sparse MLA backends. If provided,
the config will include index_topk for sparse attention.
prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer",
"cudnn", "trtllm"). Configures the attention config to
force the specified prefill backend.
"trtllm"). Configures the attention config to force
the specified prefill backend.
Returns:
VllmConfig for benchmarking
@@ -179,27 +180,13 @@ def create_minimal_vllm_config(
if prefill_backend is not None:
prefill_cfg = get_prefill_backend_config(prefill_backend)
if prefill_cfg.get("mla_prefill_backend_enum") is not None:
# Registry-based backends bypass the deprecated boolean flags.
from vllm.v1.attention.backends.mla.prefill import MLAPrefillBackendEnum
vllm_config.attention_config.mla_prefill_backend = MLAPrefillBackendEnum[
prefill_cfg["mla_prefill_backend_enum"]
vllm_config.attention_config.mla_prefill_backend = prefill_cfg[
"mla_prefill_backend"
]
else:
if prefill_cfg["flash_attn_version"] is not None:
vllm_config.attention_config.flash_attn_version = prefill_cfg[
"flash_attn_version"
]
vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[
"disable_flashinfer_prefill"
]
vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[
"use_cudnn_prefill"
]
vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = (
prefill_cfg["use_trtllm_ragged_deepseek_prefill"]
)
return vllm_config
@@ -214,34 +201,27 @@ def create_minimal_vllm_config(
_PREFILL_BACKEND_CONFIG: dict[str, dict] = {
"fa2": {
"flash_attn_version": 2,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
"mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN,
},
"fa3": {
"flash_attn_version": 3,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
"mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN,
},
"fa4": {
"flash_attn_version": 4,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
"mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN,
},
"flashinfer": {
"mla_prefill_backend_enum": "FLASHINFER",
},
"cudnn": {
# cuDNN prefill backend was removed; AttentionConfig raises on use.
"mla_prefill_backend_enum": "FLASHINFER",
"flash_attn_version": None,
"mla_prefill_backend": MLAPrefillBackendEnum.FLASHINFER,
},
"trtllm": {
"mla_prefill_backend_enum": "TRTLLM_RAGGED",
"flash_attn_version": None,
"mla_prefill_backend": MLAPrefillBackendEnum.TRTLLM_RAGGED,
},
"tokenspeed": {
"mla_prefill_backend_enum": "TOKENSPEED_MLA",
"flash_attn_version": None,
"mla_prefill_backend": MLAPrefillBackendEnum.TOKENSPEED_MLA,
},
}
@@ -1020,7 +1000,6 @@ def _run_mla_benchmark_batched(
f"version {fa_version}, got "
f"{actual_fa_version} on {actual_class}."
)
# Run each benchmark with the shared impl
for config, threshold, num_splits in configs_with_params:
# Set threshold for this benchmark (FlashAttn/FlashMLA only)
-12
View File
@@ -368,12 +368,8 @@ def test_attention_config():
"true",
"--attention-config.flash_attn_max_num_splits_for_cuda_graph",
"16",
"--attention-config.use_trtllm_ragged_deepseek_prefill",
"true",
"--attention-config.use_trtllm_attention",
"true",
"--attention-config.disable_flashinfer_prefill",
"true",
"--attention-config.disable_flashinfer_q_quantization",
"true",
]
@@ -385,9 +381,7 @@ def test_attention_config():
assert engine_args.attention_config.flash_attn_version == 3
assert engine_args.attention_config.use_prefill_decode_attention is True
assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 16
assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is True
assert engine_args.attention_config.use_trtllm_attention is True
assert engine_args.attention_config.disable_flashinfer_prefill is True
assert engine_args.attention_config.disable_flashinfer_q_quantization is True
# set to string form of a dict with all fields
@@ -397,10 +391,7 @@ def test_attention_config():
'{"backend": "FLASHINFER", "flash_attn_version": 2, '
'"use_prefill_decode_attention": false, '
'"flash_attn_max_num_splits_for_cuda_graph": 8, '
'"use_cudnn_prefill": false, '
'"use_trtllm_ragged_deepseek_prefill": false, '
'"use_trtllm_attention": false, '
'"disable_flashinfer_prefill": false, '
'"disable_flashinfer_q_quantization": false}',
]
)
@@ -411,10 +402,7 @@ def test_attention_config():
assert engine_args.attention_config.flash_attn_version == 2
assert engine_args.attention_config.use_prefill_decode_attention is False
assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 8
assert engine_args.attention_config.use_cudnn_prefill is False
assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is False
assert engine_args.attention_config.use_trtllm_attention is False
assert engine_args.attention_config.disable_flashinfer_prefill is False
assert engine_args.attention_config.disable_flashinfer_q_quantization is False
# test --attention-backend flows into VllmConfig.attention_config
@@ -269,36 +269,21 @@ class TestMLAPrefillBackendParsing:
)
class TestDeprecatedFlagMigration:
"""Tests for _migrate_deprecated_mla_prefill_flags in AttentionConfig."""
class TestMLAPrefillBackendConfig:
"""Tests for mla_prefill_backend configuration in AttentionConfig."""
def test_no_deprecated_flags_leaves_backend_none(self):
def test_default_backend_is_none(self):
config = AttentionConfig()
assert config.mla_prefill_backend is None
def test_use_trtllm_ragged_migrates_to_trtllm_ragged(self):
config = AttentionConfig(use_trtllm_ragged_deepseek_prefill=True)
assert config.mla_prefill_backend == MLAPrefillBackendEnum.TRTLLM_RAGGED
def test_disable_flashinfer_prefill_migrates_to_flash_attn(self):
config = AttentionConfig(disable_flashinfer_prefill=True)
assert config.mla_prefill_backend == MLAPrefillBackendEnum.FLASH_ATTN
def test_explicit_backend_ignores_deprecated_flags(self):
def test_explicit_flash_attn_backend(self):
config = AttentionConfig(
mla_prefill_backend=MLAPrefillBackendEnum.FLASH_ATTN,
use_cudnn_prefill=True,
)
assert config.mla_prefill_backend == MLAPrefillBackendEnum.FLASH_ATTN
def test_cudnn_raises_error(self):
match = "cuDNN MLA prefill backend has been removed"
with pytest.raises(ValueError, match=match):
AttentionConfig(use_cudnn_prefill=True)
def test_trtllm_takes_priority_over_disable_flashinfer(self):
def test_explicit_trtllm_ragged_backend(self):
config = AttentionConfig(
use_trtllm_ragged_deepseek_prefill=True,
disable_flashinfer_prefill=True,
mla_prefill_backend=MLAPrefillBackendEnum.TRTLLM_RAGGED,
)
assert config.mla_prefill_backend == MLAPrefillBackendEnum.TRTLLM_RAGGED
+1 -52
View File
@@ -6,12 +6,9 @@ from typing import Any, Literal
from pydantic import field_validator
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum
from vllm.v1.attention.backends.registry import AttentionBackendEnum
logger = init_logger(__name__)
@config
class AttentionConfig:
@@ -36,27 +33,16 @@ class AttentionConfig:
Fixes the split count so grid dimensions are constant across captures,
and buffers can be pre-allocated to avoid inflating the memory estimate."""
use_cudnn_prefill: bool = False
"""Deprecated: cuDNN prefill backend has been removed."""
use_trtllm_ragged_deepseek_prefill: bool = False
"""Whether to use TRTLLM ragged deepseek prefill."""
use_trtllm_attention: bool | None = None
"""If set to True/False, use or don't use the TRTLLM attention backend
in flashinfer. If None, auto-detect the attention backend in flashinfer."""
disable_flashinfer_prefill: bool | None = None
"""Whether to disable flashinfer prefill."""
disable_flashinfer_q_quantization: bool = False
"""If set, when using fp8 kv, do not quantize Q to fp8."""
mla_prefill_backend: MLAPrefillBackendEnum | None = None
"""MLA prefill backend to use. If None, will be selected automatically.
Valid options: FLASH_ATTN (FA3/FA4), FLASHINFER, TRTLLM_RAGGED.
This option supersedes use_trtllm_ragged_deepseek_prefill
and disable_flashinfer_prefill which are deprecated."""
Valid options: FLASH_ATTN (FA3/FA4), FLASHINFER, TRTLLM_RAGGED."""
use_prefill_query_quantization: bool = False
"""If set, quantize query for attention in prefill."""
@@ -123,40 +109,3 @@ class AttentionConfig:
if isinstance(value, str):
return MLAPrefillBackendEnum[value.upper()]
return value
def __post_init__(self) -> None:
self._migrate_deprecated_mla_prefill_flags()
def _migrate_deprecated_mla_prefill_flags(self) -> None:
"""Migrate deprecated MLA prefill flags to mla_prefill_backend."""
# If the new option is already set, it takes precedence
if self.mla_prefill_backend is not None:
return
# Check for deprecated flags and migrate them.
# Only the first flag encountered sets the backend.
if self.use_cudnn_prefill:
raise ValueError(
"The cuDNN MLA prefill backend has been removed. "
"Use --attention-config.mla_prefill_backend=FLASH_ATTN or "
"FLASHINFER or TRTLLM_RAGGED instead."
)
if self.use_trtllm_ragged_deepseek_prefill:
if self.mla_prefill_backend is None:
self.mla_prefill_backend = MLAPrefillBackendEnum.TRTLLM_RAGGED
logger.warning_once(
"use_trtllm_ragged_deepseek_prefill is deprecated and "
"will be removed in v0.22. Use "
"--attention-config.mla_prefill_backend=TRTLLM_RAGGED "
"instead."
)
if self.disable_flashinfer_prefill:
if self.mla_prefill_backend is None:
self.mla_prefill_backend = MLAPrefillBackendEnum.FLASH_ATTN
logger.warning_once(
"disable_flashinfer_prefill is deprecated and will be removed "
"in v0.22. Use --attention-config.mla_prefill_backend="
"FLASH_ATTN instead."
)
@@ -779,14 +779,14 @@ class MLAAttention(nn.Module, AttentionLayerBase):
attn_out,
lse,
get_dcp_group(),
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
is_lse_base_on_e=True,
)
else:
attn_out = cp_lse_ag_out_rs(
attn_out,
lse,
get_dcp_group(),
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
is_lse_base_on_e=True,
)
# v_up projection