From 9898f94abe005f675e0a7f40b0ae891afb5cf992 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 14 May 2026 13:34:06 -0400 Subject: [PATCH] [Attention] Remove deprecated MLA prefill arguments (#42555) Signed-off-by: Matthew Bonanni --- benchmarks/attention_benchmarks/mla_runner.py | 57 ++++++------------- tests/engine/test_arg_utils.py | 12 ---- .../v1/attention/test_mla_prefill_selector.py | 27 ++------- vllm/config/attention.py | 53 +---------------- .../layers/attention/mla_attention.py | 4 +- 5 files changed, 27 insertions(+), 126 deletions(-) diff --git a/benchmarks/attention_benchmarks/mla_runner.py b/benchmarks/attention_benchmarks/mla_runner.py index e552af01e73..abab1e2edba 100644 --- a/benchmarks/attention_benchmarks/mla_runner.py +++ b/benchmarks/attention_benchmarks/mla_runner.py @@ -29,6 +29,7 @@ from vllm.config import ( VllmConfig, set_current_vllm_config, ) +from vllm.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum # ============================================================================ # VllmConfig Creation @@ -79,8 +80,8 @@ def create_minimal_vllm_config( index_topk: Optional topk value for sparse MLA backends. If provided, the config will include index_topk for sparse attention. prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer", - "cudnn", "trtllm"). Configures the attention config to - force the specified prefill backend. + "trtllm"). Configures the attention config to force + the specified prefill backend. Returns: VllmConfig for benchmarking @@ -179,27 +180,13 @@ def create_minimal_vllm_config( if prefill_backend is not None: prefill_cfg = get_prefill_backend_config(prefill_backend) - if prefill_cfg.get("mla_prefill_backend_enum") is not None: - # Registry-based backends bypass the deprecated boolean flags. - from vllm.v1.attention.backends.mla.prefill import MLAPrefillBackendEnum - - vllm_config.attention_config.mla_prefill_backend = MLAPrefillBackendEnum[ - prefill_cfg["mla_prefill_backend_enum"] + vllm_config.attention_config.mla_prefill_backend = prefill_cfg[ + "mla_prefill_backend" + ] + if prefill_cfg["flash_attn_version"] is not None: + vllm_config.attention_config.flash_attn_version = prefill_cfg[ + "flash_attn_version" ] - else: - if prefill_cfg["flash_attn_version"] is not None: - vllm_config.attention_config.flash_attn_version = prefill_cfg[ - "flash_attn_version" - ] - vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[ - "disable_flashinfer_prefill" - ] - vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[ - "use_cudnn_prefill" - ] - vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = ( - prefill_cfg["use_trtllm_ragged_deepseek_prefill"] - ) return vllm_config @@ -214,34 +201,27 @@ def create_minimal_vllm_config( _PREFILL_BACKEND_CONFIG: dict[str, dict] = { "fa2": { "flash_attn_version": 2, - "disable_flashinfer_prefill": True, - "use_cudnn_prefill": False, - "use_trtllm_ragged_deepseek_prefill": False, + "mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN, }, "fa3": { "flash_attn_version": 3, - "disable_flashinfer_prefill": True, - "use_cudnn_prefill": False, - "use_trtllm_ragged_deepseek_prefill": False, + "mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN, }, "fa4": { "flash_attn_version": 4, - "disable_flashinfer_prefill": True, - "use_cudnn_prefill": False, - "use_trtllm_ragged_deepseek_prefill": False, + "mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN, }, "flashinfer": { - "mla_prefill_backend_enum": "FLASHINFER", - }, - "cudnn": { - # cuDNN prefill backend was removed; AttentionConfig raises on use. - "mla_prefill_backend_enum": "FLASHINFER", + "flash_attn_version": None, + "mla_prefill_backend": MLAPrefillBackendEnum.FLASHINFER, }, "trtllm": { - "mla_prefill_backend_enum": "TRTLLM_RAGGED", + "flash_attn_version": None, + "mla_prefill_backend": MLAPrefillBackendEnum.TRTLLM_RAGGED, }, "tokenspeed": { - "mla_prefill_backend_enum": "TOKENSPEED_MLA", + "flash_attn_version": None, + "mla_prefill_backend": MLAPrefillBackendEnum.TOKENSPEED_MLA, }, } @@ -1020,7 +1000,6 @@ def _run_mla_benchmark_batched( f"version {fa_version}, got " f"{actual_fa_version} on {actual_class}." ) - # Run each benchmark with the shared impl for config, threshold, num_splits in configs_with_params: # Set threshold for this benchmark (FlashAttn/FlashMLA only) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 2a2658016e2..f595ca6ecbd 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -368,12 +368,8 @@ def test_attention_config(): "true", "--attention-config.flash_attn_max_num_splits_for_cuda_graph", "16", - "--attention-config.use_trtllm_ragged_deepseek_prefill", - "true", "--attention-config.use_trtllm_attention", "true", - "--attention-config.disable_flashinfer_prefill", - "true", "--attention-config.disable_flashinfer_q_quantization", "true", ] @@ -385,9 +381,7 @@ def test_attention_config(): assert engine_args.attention_config.flash_attn_version == 3 assert engine_args.attention_config.use_prefill_decode_attention is True assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 16 - assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is True assert engine_args.attention_config.use_trtllm_attention is True - assert engine_args.attention_config.disable_flashinfer_prefill is True assert engine_args.attention_config.disable_flashinfer_q_quantization is True # set to string form of a dict with all fields @@ -397,10 +391,7 @@ def test_attention_config(): '{"backend": "FLASHINFER", "flash_attn_version": 2, ' '"use_prefill_decode_attention": false, ' '"flash_attn_max_num_splits_for_cuda_graph": 8, ' - '"use_cudnn_prefill": false, ' - '"use_trtllm_ragged_deepseek_prefill": false, ' '"use_trtllm_attention": false, ' - '"disable_flashinfer_prefill": false, ' '"disable_flashinfer_q_quantization": false}', ] ) @@ -411,10 +402,7 @@ def test_attention_config(): assert engine_args.attention_config.flash_attn_version == 2 assert engine_args.attention_config.use_prefill_decode_attention is False assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 8 - assert engine_args.attention_config.use_cudnn_prefill is False - assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is False assert engine_args.attention_config.use_trtllm_attention is False - assert engine_args.attention_config.disable_flashinfer_prefill is False assert engine_args.attention_config.disable_flashinfer_q_quantization is False # test --attention-backend flows into VllmConfig.attention_config diff --git a/tests/v1/attention/test_mla_prefill_selector.py b/tests/v1/attention/test_mla_prefill_selector.py index 068eb43faf4..d5c80c80c03 100644 --- a/tests/v1/attention/test_mla_prefill_selector.py +++ b/tests/v1/attention/test_mla_prefill_selector.py @@ -269,36 +269,21 @@ class TestMLAPrefillBackendParsing: ) -class TestDeprecatedFlagMigration: - """Tests for _migrate_deprecated_mla_prefill_flags in AttentionConfig.""" +class TestMLAPrefillBackendConfig: + """Tests for mla_prefill_backend configuration in AttentionConfig.""" - def test_no_deprecated_flags_leaves_backend_none(self): + def test_default_backend_is_none(self): config = AttentionConfig() assert config.mla_prefill_backend is None - def test_use_trtllm_ragged_migrates_to_trtllm_ragged(self): - config = AttentionConfig(use_trtllm_ragged_deepseek_prefill=True) - assert config.mla_prefill_backend == MLAPrefillBackendEnum.TRTLLM_RAGGED - - def test_disable_flashinfer_prefill_migrates_to_flash_attn(self): - config = AttentionConfig(disable_flashinfer_prefill=True) - assert config.mla_prefill_backend == MLAPrefillBackendEnum.FLASH_ATTN - - def test_explicit_backend_ignores_deprecated_flags(self): + def test_explicit_flash_attn_backend(self): config = AttentionConfig( mla_prefill_backend=MLAPrefillBackendEnum.FLASH_ATTN, - use_cudnn_prefill=True, ) assert config.mla_prefill_backend == MLAPrefillBackendEnum.FLASH_ATTN - def test_cudnn_raises_error(self): - match = "cuDNN MLA prefill backend has been removed" - with pytest.raises(ValueError, match=match): - AttentionConfig(use_cudnn_prefill=True) - - def test_trtllm_takes_priority_over_disable_flashinfer(self): + def test_explicit_trtllm_ragged_backend(self): config = AttentionConfig( - use_trtllm_ragged_deepseek_prefill=True, - disable_flashinfer_prefill=True, + mla_prefill_backend=MLAPrefillBackendEnum.TRTLLM_RAGGED, ) assert config.mla_prefill_backend == MLAPrefillBackendEnum.TRTLLM_RAGGED diff --git a/vllm/config/attention.py b/vllm/config/attention.py index b9f387bf80c..52ce9f102a6 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -6,12 +6,9 @@ from typing import Any, Literal from pydantic import field_validator from vllm.config.utils import config -from vllm.logger import init_logger from vllm.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum -logger = init_logger(__name__) - @config class AttentionConfig: @@ -36,27 +33,16 @@ class AttentionConfig: Fixes the split count so grid dimensions are constant across captures, and buffers can be pre-allocated to avoid inflating the memory estimate.""" - use_cudnn_prefill: bool = False - """Deprecated: cuDNN prefill backend has been removed.""" - - use_trtllm_ragged_deepseek_prefill: bool = False - """Whether to use TRTLLM ragged deepseek prefill.""" - use_trtllm_attention: bool | None = None """If set to True/False, use or don't use the TRTLLM attention backend in flashinfer. If None, auto-detect the attention backend in flashinfer.""" - disable_flashinfer_prefill: bool | None = None - """Whether to disable flashinfer prefill.""" - disable_flashinfer_q_quantization: bool = False """If set, when using fp8 kv, do not quantize Q to fp8.""" mla_prefill_backend: MLAPrefillBackendEnum | None = None """MLA prefill backend to use. If None, will be selected automatically. - Valid options: FLASH_ATTN (FA3/FA4), FLASHINFER, TRTLLM_RAGGED. - This option supersedes use_trtllm_ragged_deepseek_prefill - and disable_flashinfer_prefill which are deprecated.""" + Valid options: FLASH_ATTN (FA3/FA4), FLASHINFER, TRTLLM_RAGGED.""" use_prefill_query_quantization: bool = False """If set, quantize query for attention in prefill.""" @@ -123,40 +109,3 @@ class AttentionConfig: if isinstance(value, str): return MLAPrefillBackendEnum[value.upper()] return value - - def __post_init__(self) -> None: - self._migrate_deprecated_mla_prefill_flags() - - def _migrate_deprecated_mla_prefill_flags(self) -> None: - """Migrate deprecated MLA prefill flags to mla_prefill_backend.""" - # If the new option is already set, it takes precedence - if self.mla_prefill_backend is not None: - return - - # Check for deprecated flags and migrate them. - # Only the first flag encountered sets the backend. - if self.use_cudnn_prefill: - raise ValueError( - "The cuDNN MLA prefill backend has been removed. " - "Use --attention-config.mla_prefill_backend=FLASH_ATTN or " - "FLASHINFER or TRTLLM_RAGGED instead." - ) - - if self.use_trtllm_ragged_deepseek_prefill: - if self.mla_prefill_backend is None: - self.mla_prefill_backend = MLAPrefillBackendEnum.TRTLLM_RAGGED - logger.warning_once( - "use_trtllm_ragged_deepseek_prefill is deprecated and " - "will be removed in v0.22. Use " - "--attention-config.mla_prefill_backend=TRTLLM_RAGGED " - "instead." - ) - - if self.disable_flashinfer_prefill: - if self.mla_prefill_backend is None: - self.mla_prefill_backend = MLAPrefillBackendEnum.FLASH_ATTN - logger.warning_once( - "disable_flashinfer_prefill is deprecated and will be removed " - "in v0.22. Use --attention-config.mla_prefill_backend=" - "FLASH_ATTN instead." - ) diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 8b27dbdaa4f..b9ce84c2e50 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -779,14 +779,14 @@ class MLAAttention(nn.Module, AttentionLayerBase): attn_out, lse, get_dcp_group(), - is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), + is_lse_base_on_e=True, ) else: attn_out = cp_lse_ag_out_rs( attn_out, lse, get_dcp_group(), - is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), + is_lse_base_on_e=True, ) # v_up projection