[Attention] Remove deprecated MLA prefill arguments (#42555)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-05-14 13:34:06 -04:00
committed by GitHub
parent ae4f59f0ec
commit 9898f94abe
5 changed files with 27 additions and 126 deletions
+18 -39
View File
@@ -29,6 +29,7 @@ from vllm.config import (
VllmConfig, VllmConfig,
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum
# ============================================================================ # ============================================================================
# VllmConfig Creation # VllmConfig Creation
@@ -79,8 +80,8 @@ def create_minimal_vllm_config(
index_topk: Optional topk value for sparse MLA backends. If provided, index_topk: Optional topk value for sparse MLA backends. If provided,
the config will include index_topk for sparse attention. the config will include index_topk for sparse attention.
prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer", prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer",
"cudnn", "trtllm"). Configures the attention config to "trtllm"). Configures the attention config to force
force the specified prefill backend. the specified prefill backend.
Returns: Returns:
VllmConfig for benchmarking VllmConfig for benchmarking
@@ -179,27 +180,13 @@ def create_minimal_vllm_config(
if prefill_backend is not None: if prefill_backend is not None:
prefill_cfg = get_prefill_backend_config(prefill_backend) prefill_cfg = get_prefill_backend_config(prefill_backend)
if prefill_cfg.get("mla_prefill_backend_enum") is not None: vllm_config.attention_config.mla_prefill_backend = prefill_cfg[
# Registry-based backends bypass the deprecated boolean flags. "mla_prefill_backend"
from vllm.v1.attention.backends.mla.prefill import MLAPrefillBackendEnum ]
if prefill_cfg["flash_attn_version"] is not None:
vllm_config.attention_config.mla_prefill_backend = MLAPrefillBackendEnum[ vllm_config.attention_config.flash_attn_version = prefill_cfg[
prefill_cfg["mla_prefill_backend_enum"] "flash_attn_version"
] ]
else:
if prefill_cfg["flash_attn_version"] is not None:
vllm_config.attention_config.flash_attn_version = prefill_cfg[
"flash_attn_version"
]
vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[
"disable_flashinfer_prefill"
]
vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[
"use_cudnn_prefill"
]
vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = (
prefill_cfg["use_trtllm_ragged_deepseek_prefill"]
)
return vllm_config return vllm_config
@@ -214,34 +201,27 @@ def create_minimal_vllm_config(
_PREFILL_BACKEND_CONFIG: dict[str, dict] = { _PREFILL_BACKEND_CONFIG: dict[str, dict] = {
"fa2": { "fa2": {
"flash_attn_version": 2, "flash_attn_version": 2,
"disable_flashinfer_prefill": True, "mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
}, },
"fa3": { "fa3": {
"flash_attn_version": 3, "flash_attn_version": 3,
"disable_flashinfer_prefill": True, "mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
}, },
"fa4": { "fa4": {
"flash_attn_version": 4, "flash_attn_version": 4,
"disable_flashinfer_prefill": True, "mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
}, },
"flashinfer": { "flashinfer": {
"mla_prefill_backend_enum": "FLASHINFER", "flash_attn_version": None,
}, "mla_prefill_backend": MLAPrefillBackendEnum.FLASHINFER,
"cudnn": {
# cuDNN prefill backend was removed; AttentionConfig raises on use.
"mla_prefill_backend_enum": "FLASHINFER",
}, },
"trtllm": { "trtllm": {
"mla_prefill_backend_enum": "TRTLLM_RAGGED", "flash_attn_version": None,
"mla_prefill_backend": MLAPrefillBackendEnum.TRTLLM_RAGGED,
}, },
"tokenspeed": { "tokenspeed": {
"mla_prefill_backend_enum": "TOKENSPEED_MLA", "flash_attn_version": None,
"mla_prefill_backend": MLAPrefillBackendEnum.TOKENSPEED_MLA,
}, },
} }
@@ -1020,7 +1000,6 @@ def _run_mla_benchmark_batched(
f"version {fa_version}, got " f"version {fa_version}, got "
f"{actual_fa_version} on {actual_class}." f"{actual_fa_version} on {actual_class}."
) )
# Run each benchmark with the shared impl # Run each benchmark with the shared impl
for config, threshold, num_splits in configs_with_params: for config, threshold, num_splits in configs_with_params:
# Set threshold for this benchmark (FlashAttn/FlashMLA only) # Set threshold for this benchmark (FlashAttn/FlashMLA only)
-12
View File
@@ -368,12 +368,8 @@ def test_attention_config():
"true", "true",
"--attention-config.flash_attn_max_num_splits_for_cuda_graph", "--attention-config.flash_attn_max_num_splits_for_cuda_graph",
"16", "16",
"--attention-config.use_trtllm_ragged_deepseek_prefill",
"true",
"--attention-config.use_trtllm_attention", "--attention-config.use_trtllm_attention",
"true", "true",
"--attention-config.disable_flashinfer_prefill",
"true",
"--attention-config.disable_flashinfer_q_quantization", "--attention-config.disable_flashinfer_q_quantization",
"true", "true",
] ]
@@ -385,9 +381,7 @@ def test_attention_config():
assert engine_args.attention_config.flash_attn_version == 3 assert engine_args.attention_config.flash_attn_version == 3
assert engine_args.attention_config.use_prefill_decode_attention is True assert engine_args.attention_config.use_prefill_decode_attention is True
assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 16 assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 16
assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is True
assert engine_args.attention_config.use_trtllm_attention is True assert engine_args.attention_config.use_trtllm_attention is True
assert engine_args.attention_config.disable_flashinfer_prefill is True
assert engine_args.attention_config.disable_flashinfer_q_quantization is True assert engine_args.attention_config.disable_flashinfer_q_quantization is True
# set to string form of a dict with all fields # set to string form of a dict with all fields
@@ -397,10 +391,7 @@ def test_attention_config():
'{"backend": "FLASHINFER", "flash_attn_version": 2, ' '{"backend": "FLASHINFER", "flash_attn_version": 2, '
'"use_prefill_decode_attention": false, ' '"use_prefill_decode_attention": false, '
'"flash_attn_max_num_splits_for_cuda_graph": 8, ' '"flash_attn_max_num_splits_for_cuda_graph": 8, '
'"use_cudnn_prefill": false, '
'"use_trtllm_ragged_deepseek_prefill": false, '
'"use_trtllm_attention": false, ' '"use_trtllm_attention": false, '
'"disable_flashinfer_prefill": false, '
'"disable_flashinfer_q_quantization": false}', '"disable_flashinfer_q_quantization": false}',
] ]
) )
@@ -411,10 +402,7 @@ def test_attention_config():
assert engine_args.attention_config.flash_attn_version == 2 assert engine_args.attention_config.flash_attn_version == 2
assert engine_args.attention_config.use_prefill_decode_attention is False assert engine_args.attention_config.use_prefill_decode_attention is False
assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 8 assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 8
assert engine_args.attention_config.use_cudnn_prefill is False
assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is False
assert engine_args.attention_config.use_trtllm_attention is False assert engine_args.attention_config.use_trtllm_attention is False
assert engine_args.attention_config.disable_flashinfer_prefill is False
assert engine_args.attention_config.disable_flashinfer_q_quantization is False assert engine_args.attention_config.disable_flashinfer_q_quantization is False
# test --attention-backend flows into VllmConfig.attention_config # test --attention-backend flows into VllmConfig.attention_config
@@ -269,36 +269,21 @@ class TestMLAPrefillBackendParsing:
) )
class TestDeprecatedFlagMigration: class TestMLAPrefillBackendConfig:
"""Tests for _migrate_deprecated_mla_prefill_flags in AttentionConfig.""" """Tests for mla_prefill_backend configuration in AttentionConfig."""
def test_no_deprecated_flags_leaves_backend_none(self): def test_default_backend_is_none(self):
config = AttentionConfig() config = AttentionConfig()
assert config.mla_prefill_backend is None assert config.mla_prefill_backend is None
def test_use_trtllm_ragged_migrates_to_trtllm_ragged(self): def test_explicit_flash_attn_backend(self):
config = AttentionConfig(use_trtllm_ragged_deepseek_prefill=True)
assert config.mla_prefill_backend == MLAPrefillBackendEnum.TRTLLM_RAGGED
def test_disable_flashinfer_prefill_migrates_to_flash_attn(self):
config = AttentionConfig(disable_flashinfer_prefill=True)
assert config.mla_prefill_backend == MLAPrefillBackendEnum.FLASH_ATTN
def test_explicit_backend_ignores_deprecated_flags(self):
config = AttentionConfig( config = AttentionConfig(
mla_prefill_backend=MLAPrefillBackendEnum.FLASH_ATTN, mla_prefill_backend=MLAPrefillBackendEnum.FLASH_ATTN,
use_cudnn_prefill=True,
) )
assert config.mla_prefill_backend == MLAPrefillBackendEnum.FLASH_ATTN assert config.mla_prefill_backend == MLAPrefillBackendEnum.FLASH_ATTN
def test_cudnn_raises_error(self): def test_explicit_trtllm_ragged_backend(self):
match = "cuDNN MLA prefill backend has been removed"
with pytest.raises(ValueError, match=match):
AttentionConfig(use_cudnn_prefill=True)
def test_trtllm_takes_priority_over_disable_flashinfer(self):
config = AttentionConfig( config = AttentionConfig(
use_trtllm_ragged_deepseek_prefill=True, mla_prefill_backend=MLAPrefillBackendEnum.TRTLLM_RAGGED,
disable_flashinfer_prefill=True,
) )
assert config.mla_prefill_backend == MLAPrefillBackendEnum.TRTLLM_RAGGED assert config.mla_prefill_backend == MLAPrefillBackendEnum.TRTLLM_RAGGED
+1 -52
View File
@@ -6,12 +6,9 @@ from typing import Any, Literal
from pydantic import field_validator from pydantic import field_validator
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum from vllm.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
logger = init_logger(__name__)
@config @config
class AttentionConfig: class AttentionConfig:
@@ -36,27 +33,16 @@ class AttentionConfig:
Fixes the split count so grid dimensions are constant across captures, Fixes the split count so grid dimensions are constant across captures,
and buffers can be pre-allocated to avoid inflating the memory estimate.""" and buffers can be pre-allocated to avoid inflating the memory estimate."""
use_cudnn_prefill: bool = False
"""Deprecated: cuDNN prefill backend has been removed."""
use_trtllm_ragged_deepseek_prefill: bool = False
"""Whether to use TRTLLM ragged deepseek prefill."""
use_trtllm_attention: bool | None = None use_trtllm_attention: bool | None = None
"""If set to True/False, use or don't use the TRTLLM attention backend """If set to True/False, use or don't use the TRTLLM attention backend
in flashinfer. If None, auto-detect the attention backend in flashinfer.""" in flashinfer. If None, auto-detect the attention backend in flashinfer."""
disable_flashinfer_prefill: bool | None = None
"""Whether to disable flashinfer prefill."""
disable_flashinfer_q_quantization: bool = False disable_flashinfer_q_quantization: bool = False
"""If set, when using fp8 kv, do not quantize Q to fp8.""" """If set, when using fp8 kv, do not quantize Q to fp8."""
mla_prefill_backend: MLAPrefillBackendEnum | None = None mla_prefill_backend: MLAPrefillBackendEnum | None = None
"""MLA prefill backend to use. If None, will be selected automatically. """MLA prefill backend to use. If None, will be selected automatically.
Valid options: FLASH_ATTN (FA3/FA4), FLASHINFER, TRTLLM_RAGGED. Valid options: FLASH_ATTN (FA3/FA4), FLASHINFER, TRTLLM_RAGGED."""
This option supersedes use_trtllm_ragged_deepseek_prefill
and disable_flashinfer_prefill which are deprecated."""
use_prefill_query_quantization: bool = False use_prefill_query_quantization: bool = False
"""If set, quantize query for attention in prefill.""" """If set, quantize query for attention in prefill."""
@@ -123,40 +109,3 @@ class AttentionConfig:
if isinstance(value, str): if isinstance(value, str):
return MLAPrefillBackendEnum[value.upper()] return MLAPrefillBackendEnum[value.upper()]
return value return value
def __post_init__(self) -> None:
self._migrate_deprecated_mla_prefill_flags()
def _migrate_deprecated_mla_prefill_flags(self) -> None:
"""Migrate deprecated MLA prefill flags to mla_prefill_backend."""
# If the new option is already set, it takes precedence
if self.mla_prefill_backend is not None:
return
# Check for deprecated flags and migrate them.
# Only the first flag encountered sets the backend.
if self.use_cudnn_prefill:
raise ValueError(
"The cuDNN MLA prefill backend has been removed. "
"Use --attention-config.mla_prefill_backend=FLASH_ATTN or "
"FLASHINFER or TRTLLM_RAGGED instead."
)
if self.use_trtllm_ragged_deepseek_prefill:
if self.mla_prefill_backend is None:
self.mla_prefill_backend = MLAPrefillBackendEnum.TRTLLM_RAGGED
logger.warning_once(
"use_trtllm_ragged_deepseek_prefill is deprecated and "
"will be removed in v0.22. Use "
"--attention-config.mla_prefill_backend=TRTLLM_RAGGED "
"instead."
)
if self.disable_flashinfer_prefill:
if self.mla_prefill_backend is None:
self.mla_prefill_backend = MLAPrefillBackendEnum.FLASH_ATTN
logger.warning_once(
"disable_flashinfer_prefill is deprecated and will be removed "
"in v0.22. Use --attention-config.mla_prefill_backend="
"FLASH_ATTN instead."
)
@@ -779,14 +779,14 @@ class MLAAttention(nn.Module, AttentionLayerBase):
attn_out, attn_out,
lse, lse,
get_dcp_group(), get_dcp_group(),
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), is_lse_base_on_e=True,
) )
else: else:
attn_out = cp_lse_ag_out_rs( attn_out = cp_lse_ag_out_rs(
attn_out, attn_out,
lse, lse,
get_dcp_group(), get_dcp_group(),
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), is_lse_base_on_e=True,
) )
# v_up projection # v_up projection