mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Attention] Remove deprecated MLA prefill arguments (#42555)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -29,6 +29,7 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum
|
||||
|
||||
# ============================================================================
|
||||
# VllmConfig Creation
|
||||
@@ -79,8 +80,8 @@ def create_minimal_vllm_config(
|
||||
index_topk: Optional topk value for sparse MLA backends. If provided,
|
||||
the config will include index_topk for sparse attention.
|
||||
prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer",
|
||||
"cudnn", "trtllm"). Configures the attention config to
|
||||
force the specified prefill backend.
|
||||
"trtllm"). Configures the attention config to force
|
||||
the specified prefill backend.
|
||||
|
||||
Returns:
|
||||
VllmConfig for benchmarking
|
||||
@@ -179,27 +180,13 @@ def create_minimal_vllm_config(
|
||||
|
||||
if prefill_backend is not None:
|
||||
prefill_cfg = get_prefill_backend_config(prefill_backend)
|
||||
if prefill_cfg.get("mla_prefill_backend_enum") is not None:
|
||||
# Registry-based backends bypass the deprecated boolean flags.
|
||||
from vllm.v1.attention.backends.mla.prefill import MLAPrefillBackendEnum
|
||||
|
||||
vllm_config.attention_config.mla_prefill_backend = MLAPrefillBackendEnum[
|
||||
prefill_cfg["mla_prefill_backend_enum"]
|
||||
vllm_config.attention_config.mla_prefill_backend = prefill_cfg[
|
||||
"mla_prefill_backend"
|
||||
]
|
||||
if prefill_cfg["flash_attn_version"] is not None:
|
||||
vllm_config.attention_config.flash_attn_version = prefill_cfg[
|
||||
"flash_attn_version"
|
||||
]
|
||||
else:
|
||||
if prefill_cfg["flash_attn_version"] is not None:
|
||||
vllm_config.attention_config.flash_attn_version = prefill_cfg[
|
||||
"flash_attn_version"
|
||||
]
|
||||
vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[
|
||||
"disable_flashinfer_prefill"
|
||||
]
|
||||
vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[
|
||||
"use_cudnn_prefill"
|
||||
]
|
||||
vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = (
|
||||
prefill_cfg["use_trtllm_ragged_deepseek_prefill"]
|
||||
)
|
||||
|
||||
return vllm_config
|
||||
|
||||
@@ -214,34 +201,27 @@ def create_minimal_vllm_config(
|
||||
_PREFILL_BACKEND_CONFIG: dict[str, dict] = {
|
||||
"fa2": {
|
||||
"flash_attn_version": 2,
|
||||
"disable_flashinfer_prefill": True,
|
||||
"use_cudnn_prefill": False,
|
||||
"use_trtllm_ragged_deepseek_prefill": False,
|
||||
"mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN,
|
||||
},
|
||||
"fa3": {
|
||||
"flash_attn_version": 3,
|
||||
"disable_flashinfer_prefill": True,
|
||||
"use_cudnn_prefill": False,
|
||||
"use_trtllm_ragged_deepseek_prefill": False,
|
||||
"mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN,
|
||||
},
|
||||
"fa4": {
|
||||
"flash_attn_version": 4,
|
||||
"disable_flashinfer_prefill": True,
|
||||
"use_cudnn_prefill": False,
|
||||
"use_trtllm_ragged_deepseek_prefill": False,
|
||||
"mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN,
|
||||
},
|
||||
"flashinfer": {
|
||||
"mla_prefill_backend_enum": "FLASHINFER",
|
||||
},
|
||||
"cudnn": {
|
||||
# cuDNN prefill backend was removed; AttentionConfig raises on use.
|
||||
"mla_prefill_backend_enum": "FLASHINFER",
|
||||
"flash_attn_version": None,
|
||||
"mla_prefill_backend": MLAPrefillBackendEnum.FLASHINFER,
|
||||
},
|
||||
"trtllm": {
|
||||
"mla_prefill_backend_enum": "TRTLLM_RAGGED",
|
||||
"flash_attn_version": None,
|
||||
"mla_prefill_backend": MLAPrefillBackendEnum.TRTLLM_RAGGED,
|
||||
},
|
||||
"tokenspeed": {
|
||||
"mla_prefill_backend_enum": "TOKENSPEED_MLA",
|
||||
"flash_attn_version": None,
|
||||
"mla_prefill_backend": MLAPrefillBackendEnum.TOKENSPEED_MLA,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1020,7 +1000,6 @@ def _run_mla_benchmark_batched(
|
||||
f"version {fa_version}, got "
|
||||
f"{actual_fa_version} on {actual_class}."
|
||||
)
|
||||
|
||||
# Run each benchmark with the shared impl
|
||||
for config, threshold, num_splits in configs_with_params:
|
||||
# Set threshold for this benchmark (FlashAttn/FlashMLA only)
|
||||
|
||||
@@ -368,12 +368,8 @@ def test_attention_config():
|
||||
"true",
|
||||
"--attention-config.flash_attn_max_num_splits_for_cuda_graph",
|
||||
"16",
|
||||
"--attention-config.use_trtllm_ragged_deepseek_prefill",
|
||||
"true",
|
||||
"--attention-config.use_trtllm_attention",
|
||||
"true",
|
||||
"--attention-config.disable_flashinfer_prefill",
|
||||
"true",
|
||||
"--attention-config.disable_flashinfer_q_quantization",
|
||||
"true",
|
||||
]
|
||||
@@ -385,9 +381,7 @@ def test_attention_config():
|
||||
assert engine_args.attention_config.flash_attn_version == 3
|
||||
assert engine_args.attention_config.use_prefill_decode_attention is True
|
||||
assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 16
|
||||
assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is True
|
||||
assert engine_args.attention_config.use_trtllm_attention is True
|
||||
assert engine_args.attention_config.disable_flashinfer_prefill is True
|
||||
assert engine_args.attention_config.disable_flashinfer_q_quantization is True
|
||||
|
||||
# set to string form of a dict with all fields
|
||||
@@ -397,10 +391,7 @@ def test_attention_config():
|
||||
'{"backend": "FLASHINFER", "flash_attn_version": 2, '
|
||||
'"use_prefill_decode_attention": false, '
|
||||
'"flash_attn_max_num_splits_for_cuda_graph": 8, '
|
||||
'"use_cudnn_prefill": false, '
|
||||
'"use_trtllm_ragged_deepseek_prefill": false, '
|
||||
'"use_trtllm_attention": false, '
|
||||
'"disable_flashinfer_prefill": false, '
|
||||
'"disable_flashinfer_q_quantization": false}',
|
||||
]
|
||||
)
|
||||
@@ -411,10 +402,7 @@ def test_attention_config():
|
||||
assert engine_args.attention_config.flash_attn_version == 2
|
||||
assert engine_args.attention_config.use_prefill_decode_attention is False
|
||||
assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 8
|
||||
assert engine_args.attention_config.use_cudnn_prefill is False
|
||||
assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is False
|
||||
assert engine_args.attention_config.use_trtllm_attention is False
|
||||
assert engine_args.attention_config.disable_flashinfer_prefill is False
|
||||
assert engine_args.attention_config.disable_flashinfer_q_quantization is False
|
||||
|
||||
# test --attention-backend flows into VllmConfig.attention_config
|
||||
|
||||
@@ -269,36 +269,21 @@ class TestMLAPrefillBackendParsing:
|
||||
)
|
||||
|
||||
|
||||
class TestDeprecatedFlagMigration:
|
||||
"""Tests for _migrate_deprecated_mla_prefill_flags in AttentionConfig."""
|
||||
class TestMLAPrefillBackendConfig:
|
||||
"""Tests for mla_prefill_backend configuration in AttentionConfig."""
|
||||
|
||||
def test_no_deprecated_flags_leaves_backend_none(self):
|
||||
def test_default_backend_is_none(self):
|
||||
config = AttentionConfig()
|
||||
assert config.mla_prefill_backend is None
|
||||
|
||||
def test_use_trtllm_ragged_migrates_to_trtllm_ragged(self):
|
||||
config = AttentionConfig(use_trtllm_ragged_deepseek_prefill=True)
|
||||
assert config.mla_prefill_backend == MLAPrefillBackendEnum.TRTLLM_RAGGED
|
||||
|
||||
def test_disable_flashinfer_prefill_migrates_to_flash_attn(self):
|
||||
config = AttentionConfig(disable_flashinfer_prefill=True)
|
||||
assert config.mla_prefill_backend == MLAPrefillBackendEnum.FLASH_ATTN
|
||||
|
||||
def test_explicit_backend_ignores_deprecated_flags(self):
|
||||
def test_explicit_flash_attn_backend(self):
|
||||
config = AttentionConfig(
|
||||
mla_prefill_backend=MLAPrefillBackendEnum.FLASH_ATTN,
|
||||
use_cudnn_prefill=True,
|
||||
)
|
||||
assert config.mla_prefill_backend == MLAPrefillBackendEnum.FLASH_ATTN
|
||||
|
||||
def test_cudnn_raises_error(self):
|
||||
match = "cuDNN MLA prefill backend has been removed"
|
||||
with pytest.raises(ValueError, match=match):
|
||||
AttentionConfig(use_cudnn_prefill=True)
|
||||
|
||||
def test_trtllm_takes_priority_over_disable_flashinfer(self):
|
||||
def test_explicit_trtllm_ragged_backend(self):
|
||||
config = AttentionConfig(
|
||||
use_trtllm_ragged_deepseek_prefill=True,
|
||||
disable_flashinfer_prefill=True,
|
||||
mla_prefill_backend=MLAPrefillBackendEnum.TRTLLM_RAGGED,
|
||||
)
|
||||
assert config.mla_prefill_backend == MLAPrefillBackendEnum.TRTLLM_RAGGED
|
||||
|
||||
@@ -6,12 +6,9 @@ from typing import Any, Literal
|
||||
from pydantic import field_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@config
|
||||
class AttentionConfig:
|
||||
@@ -36,27 +33,16 @@ class AttentionConfig:
|
||||
Fixes the split count so grid dimensions are constant across captures,
|
||||
and buffers can be pre-allocated to avoid inflating the memory estimate."""
|
||||
|
||||
use_cudnn_prefill: bool = False
|
||||
"""Deprecated: cuDNN prefill backend has been removed."""
|
||||
|
||||
use_trtllm_ragged_deepseek_prefill: bool = False
|
||||
"""Whether to use TRTLLM ragged deepseek prefill."""
|
||||
|
||||
use_trtllm_attention: bool | None = None
|
||||
"""If set to True/False, use or don't use the TRTLLM attention backend
|
||||
in flashinfer. If None, auto-detect the attention backend in flashinfer."""
|
||||
|
||||
disable_flashinfer_prefill: bool | None = None
|
||||
"""Whether to disable flashinfer prefill."""
|
||||
|
||||
disable_flashinfer_q_quantization: bool = False
|
||||
"""If set, when using fp8 kv, do not quantize Q to fp8."""
|
||||
|
||||
mla_prefill_backend: MLAPrefillBackendEnum | None = None
|
||||
"""MLA prefill backend to use. If None, will be selected automatically.
|
||||
Valid options: FLASH_ATTN (FA3/FA4), FLASHINFER, TRTLLM_RAGGED.
|
||||
This option supersedes use_trtllm_ragged_deepseek_prefill
|
||||
and disable_flashinfer_prefill which are deprecated."""
|
||||
Valid options: FLASH_ATTN (FA3/FA4), FLASHINFER, TRTLLM_RAGGED."""
|
||||
|
||||
use_prefill_query_quantization: bool = False
|
||||
"""If set, quantize query for attention in prefill."""
|
||||
@@ -123,40 +109,3 @@ class AttentionConfig:
|
||||
if isinstance(value, str):
|
||||
return MLAPrefillBackendEnum[value.upper()]
|
||||
return value
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._migrate_deprecated_mla_prefill_flags()
|
||||
|
||||
def _migrate_deprecated_mla_prefill_flags(self) -> None:
|
||||
"""Migrate deprecated MLA prefill flags to mla_prefill_backend."""
|
||||
# If the new option is already set, it takes precedence
|
||||
if self.mla_prefill_backend is not None:
|
||||
return
|
||||
|
||||
# Check for deprecated flags and migrate them.
|
||||
# Only the first flag encountered sets the backend.
|
||||
if self.use_cudnn_prefill:
|
||||
raise ValueError(
|
||||
"The cuDNN MLA prefill backend has been removed. "
|
||||
"Use --attention-config.mla_prefill_backend=FLASH_ATTN or "
|
||||
"FLASHINFER or TRTLLM_RAGGED instead."
|
||||
)
|
||||
|
||||
if self.use_trtllm_ragged_deepseek_prefill:
|
||||
if self.mla_prefill_backend is None:
|
||||
self.mla_prefill_backend = MLAPrefillBackendEnum.TRTLLM_RAGGED
|
||||
logger.warning_once(
|
||||
"use_trtllm_ragged_deepseek_prefill is deprecated and "
|
||||
"will be removed in v0.22. Use "
|
||||
"--attention-config.mla_prefill_backend=TRTLLM_RAGGED "
|
||||
"instead."
|
||||
)
|
||||
|
||||
if self.disable_flashinfer_prefill:
|
||||
if self.mla_prefill_backend is None:
|
||||
self.mla_prefill_backend = MLAPrefillBackendEnum.FLASH_ATTN
|
||||
logger.warning_once(
|
||||
"disable_flashinfer_prefill is deprecated and will be removed "
|
||||
"in v0.22. Use --attention-config.mla_prefill_backend="
|
||||
"FLASH_ATTN instead."
|
||||
)
|
||||
|
||||
@@ -779,14 +779,14 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
attn_out,
|
||||
lse,
|
||||
get_dcp_group(),
|
||||
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
|
||||
is_lse_base_on_e=True,
|
||||
)
|
||||
else:
|
||||
attn_out = cp_lse_ag_out_rs(
|
||||
attn_out,
|
||||
lse,
|
||||
get_dcp_group(),
|
||||
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
|
||||
is_lse_base_on_e=True,
|
||||
)
|
||||
|
||||
# v_up projection
|
||||
|
||||
Reference in New Issue
Block a user