mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Attention] Remove deprecated MLA prefill arguments (#42555)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -29,6 +29,7 @@ from vllm.config import (
|
|||||||
VllmConfig,
|
VllmConfig,
|
||||||
set_current_vllm_config,
|
set_current_vllm_config,
|
||||||
)
|
)
|
||||||
|
from vllm.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# VllmConfig Creation
|
# VllmConfig Creation
|
||||||
@@ -79,8 +80,8 @@ def create_minimal_vllm_config(
|
|||||||
index_topk: Optional topk value for sparse MLA backends. If provided,
|
index_topk: Optional topk value for sparse MLA backends. If provided,
|
||||||
the config will include index_topk for sparse attention.
|
the config will include index_topk for sparse attention.
|
||||||
prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer",
|
prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer",
|
||||||
"cudnn", "trtllm"). Configures the attention config to
|
"trtllm"). Configures the attention config to force
|
||||||
force the specified prefill backend.
|
the specified prefill backend.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
VllmConfig for benchmarking
|
VllmConfig for benchmarking
|
||||||
@@ -179,27 +180,13 @@ def create_minimal_vllm_config(
|
|||||||
|
|
||||||
if prefill_backend is not None:
|
if prefill_backend is not None:
|
||||||
prefill_cfg = get_prefill_backend_config(prefill_backend)
|
prefill_cfg = get_prefill_backend_config(prefill_backend)
|
||||||
if prefill_cfg.get("mla_prefill_backend_enum") is not None:
|
vllm_config.attention_config.mla_prefill_backend = prefill_cfg[
|
||||||
# Registry-based backends bypass the deprecated boolean flags.
|
"mla_prefill_backend"
|
||||||
from vllm.v1.attention.backends.mla.prefill import MLAPrefillBackendEnum
|
]
|
||||||
|
if prefill_cfg["flash_attn_version"] is not None:
|
||||||
vllm_config.attention_config.mla_prefill_backend = MLAPrefillBackendEnum[
|
vllm_config.attention_config.flash_attn_version = prefill_cfg[
|
||||||
prefill_cfg["mla_prefill_backend_enum"]
|
"flash_attn_version"
|
||||||
]
|
]
|
||||||
else:
|
|
||||||
if prefill_cfg["flash_attn_version"] is not None:
|
|
||||||
vllm_config.attention_config.flash_attn_version = prefill_cfg[
|
|
||||||
"flash_attn_version"
|
|
||||||
]
|
|
||||||
vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[
|
|
||||||
"disable_flashinfer_prefill"
|
|
||||||
]
|
|
||||||
vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[
|
|
||||||
"use_cudnn_prefill"
|
|
||||||
]
|
|
||||||
vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = (
|
|
||||||
prefill_cfg["use_trtllm_ragged_deepseek_prefill"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return vllm_config
|
return vllm_config
|
||||||
|
|
||||||
@@ -214,34 +201,27 @@ def create_minimal_vllm_config(
|
|||||||
_PREFILL_BACKEND_CONFIG: dict[str, dict] = {
|
_PREFILL_BACKEND_CONFIG: dict[str, dict] = {
|
||||||
"fa2": {
|
"fa2": {
|
||||||
"flash_attn_version": 2,
|
"flash_attn_version": 2,
|
||||||
"disable_flashinfer_prefill": True,
|
"mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN,
|
||||||
"use_cudnn_prefill": False,
|
|
||||||
"use_trtllm_ragged_deepseek_prefill": False,
|
|
||||||
},
|
},
|
||||||
"fa3": {
|
"fa3": {
|
||||||
"flash_attn_version": 3,
|
"flash_attn_version": 3,
|
||||||
"disable_flashinfer_prefill": True,
|
"mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN,
|
||||||
"use_cudnn_prefill": False,
|
|
||||||
"use_trtllm_ragged_deepseek_prefill": False,
|
|
||||||
},
|
},
|
||||||
"fa4": {
|
"fa4": {
|
||||||
"flash_attn_version": 4,
|
"flash_attn_version": 4,
|
||||||
"disable_flashinfer_prefill": True,
|
"mla_prefill_backend": MLAPrefillBackendEnum.FLASH_ATTN,
|
||||||
"use_cudnn_prefill": False,
|
|
||||||
"use_trtllm_ragged_deepseek_prefill": False,
|
|
||||||
},
|
},
|
||||||
"flashinfer": {
|
"flashinfer": {
|
||||||
"mla_prefill_backend_enum": "FLASHINFER",
|
"flash_attn_version": None,
|
||||||
},
|
"mla_prefill_backend": MLAPrefillBackendEnum.FLASHINFER,
|
||||||
"cudnn": {
|
|
||||||
# cuDNN prefill backend was removed; AttentionConfig raises on use.
|
|
||||||
"mla_prefill_backend_enum": "FLASHINFER",
|
|
||||||
},
|
},
|
||||||
"trtllm": {
|
"trtllm": {
|
||||||
"mla_prefill_backend_enum": "TRTLLM_RAGGED",
|
"flash_attn_version": None,
|
||||||
|
"mla_prefill_backend": MLAPrefillBackendEnum.TRTLLM_RAGGED,
|
||||||
},
|
},
|
||||||
"tokenspeed": {
|
"tokenspeed": {
|
||||||
"mla_prefill_backend_enum": "TOKENSPEED_MLA",
|
"flash_attn_version": None,
|
||||||
|
"mla_prefill_backend": MLAPrefillBackendEnum.TOKENSPEED_MLA,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1020,7 +1000,6 @@ def _run_mla_benchmark_batched(
|
|||||||
f"version {fa_version}, got "
|
f"version {fa_version}, got "
|
||||||
f"{actual_fa_version} on {actual_class}."
|
f"{actual_fa_version} on {actual_class}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run each benchmark with the shared impl
|
# Run each benchmark with the shared impl
|
||||||
for config, threshold, num_splits in configs_with_params:
|
for config, threshold, num_splits in configs_with_params:
|
||||||
# Set threshold for this benchmark (FlashAttn/FlashMLA only)
|
# Set threshold for this benchmark (FlashAttn/FlashMLA only)
|
||||||
|
|||||||
@@ -368,12 +368,8 @@ def test_attention_config():
|
|||||||
"true",
|
"true",
|
||||||
"--attention-config.flash_attn_max_num_splits_for_cuda_graph",
|
"--attention-config.flash_attn_max_num_splits_for_cuda_graph",
|
||||||
"16",
|
"16",
|
||||||
"--attention-config.use_trtllm_ragged_deepseek_prefill",
|
|
||||||
"true",
|
|
||||||
"--attention-config.use_trtllm_attention",
|
"--attention-config.use_trtllm_attention",
|
||||||
"true",
|
"true",
|
||||||
"--attention-config.disable_flashinfer_prefill",
|
|
||||||
"true",
|
|
||||||
"--attention-config.disable_flashinfer_q_quantization",
|
"--attention-config.disable_flashinfer_q_quantization",
|
||||||
"true",
|
"true",
|
||||||
]
|
]
|
||||||
@@ -385,9 +381,7 @@ def test_attention_config():
|
|||||||
assert engine_args.attention_config.flash_attn_version == 3
|
assert engine_args.attention_config.flash_attn_version == 3
|
||||||
assert engine_args.attention_config.use_prefill_decode_attention is True
|
assert engine_args.attention_config.use_prefill_decode_attention is True
|
||||||
assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 16
|
assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 16
|
||||||
assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is True
|
|
||||||
assert engine_args.attention_config.use_trtllm_attention is True
|
assert engine_args.attention_config.use_trtllm_attention is True
|
||||||
assert engine_args.attention_config.disable_flashinfer_prefill is True
|
|
||||||
assert engine_args.attention_config.disable_flashinfer_q_quantization is True
|
assert engine_args.attention_config.disable_flashinfer_q_quantization is True
|
||||||
|
|
||||||
# set to string form of a dict with all fields
|
# set to string form of a dict with all fields
|
||||||
@@ -397,10 +391,7 @@ def test_attention_config():
|
|||||||
'{"backend": "FLASHINFER", "flash_attn_version": 2, '
|
'{"backend": "FLASHINFER", "flash_attn_version": 2, '
|
||||||
'"use_prefill_decode_attention": false, '
|
'"use_prefill_decode_attention": false, '
|
||||||
'"flash_attn_max_num_splits_for_cuda_graph": 8, '
|
'"flash_attn_max_num_splits_for_cuda_graph": 8, '
|
||||||
'"use_cudnn_prefill": false, '
|
|
||||||
'"use_trtllm_ragged_deepseek_prefill": false, '
|
|
||||||
'"use_trtllm_attention": false, '
|
'"use_trtllm_attention": false, '
|
||||||
'"disable_flashinfer_prefill": false, '
|
|
||||||
'"disable_flashinfer_q_quantization": false}',
|
'"disable_flashinfer_q_quantization": false}',
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -411,10 +402,7 @@ def test_attention_config():
|
|||||||
assert engine_args.attention_config.flash_attn_version == 2
|
assert engine_args.attention_config.flash_attn_version == 2
|
||||||
assert engine_args.attention_config.use_prefill_decode_attention is False
|
assert engine_args.attention_config.use_prefill_decode_attention is False
|
||||||
assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 8
|
assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 8
|
||||||
assert engine_args.attention_config.use_cudnn_prefill is False
|
|
||||||
assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is False
|
|
||||||
assert engine_args.attention_config.use_trtllm_attention is False
|
assert engine_args.attention_config.use_trtllm_attention is False
|
||||||
assert engine_args.attention_config.disable_flashinfer_prefill is False
|
|
||||||
assert engine_args.attention_config.disable_flashinfer_q_quantization is False
|
assert engine_args.attention_config.disable_flashinfer_q_quantization is False
|
||||||
|
|
||||||
# test --attention-backend flows into VllmConfig.attention_config
|
# test --attention-backend flows into VllmConfig.attention_config
|
||||||
|
|||||||
@@ -269,36 +269,21 @@ class TestMLAPrefillBackendParsing:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestDeprecatedFlagMigration:
|
class TestMLAPrefillBackendConfig:
|
||||||
"""Tests for _migrate_deprecated_mla_prefill_flags in AttentionConfig."""
|
"""Tests for mla_prefill_backend configuration in AttentionConfig."""
|
||||||
|
|
||||||
def test_no_deprecated_flags_leaves_backend_none(self):
|
def test_default_backend_is_none(self):
|
||||||
config = AttentionConfig()
|
config = AttentionConfig()
|
||||||
assert config.mla_prefill_backend is None
|
assert config.mla_prefill_backend is None
|
||||||
|
|
||||||
def test_use_trtllm_ragged_migrates_to_trtllm_ragged(self):
|
def test_explicit_flash_attn_backend(self):
|
||||||
config = AttentionConfig(use_trtllm_ragged_deepseek_prefill=True)
|
|
||||||
assert config.mla_prefill_backend == MLAPrefillBackendEnum.TRTLLM_RAGGED
|
|
||||||
|
|
||||||
def test_disable_flashinfer_prefill_migrates_to_flash_attn(self):
|
|
||||||
config = AttentionConfig(disable_flashinfer_prefill=True)
|
|
||||||
assert config.mla_prefill_backend == MLAPrefillBackendEnum.FLASH_ATTN
|
|
||||||
|
|
||||||
def test_explicit_backend_ignores_deprecated_flags(self):
|
|
||||||
config = AttentionConfig(
|
config = AttentionConfig(
|
||||||
mla_prefill_backend=MLAPrefillBackendEnum.FLASH_ATTN,
|
mla_prefill_backend=MLAPrefillBackendEnum.FLASH_ATTN,
|
||||||
use_cudnn_prefill=True,
|
|
||||||
)
|
)
|
||||||
assert config.mla_prefill_backend == MLAPrefillBackendEnum.FLASH_ATTN
|
assert config.mla_prefill_backend == MLAPrefillBackendEnum.FLASH_ATTN
|
||||||
|
|
||||||
def test_cudnn_raises_error(self):
|
def test_explicit_trtllm_ragged_backend(self):
|
||||||
match = "cuDNN MLA prefill backend has been removed"
|
|
||||||
with pytest.raises(ValueError, match=match):
|
|
||||||
AttentionConfig(use_cudnn_prefill=True)
|
|
||||||
|
|
||||||
def test_trtllm_takes_priority_over_disable_flashinfer(self):
|
|
||||||
config = AttentionConfig(
|
config = AttentionConfig(
|
||||||
use_trtllm_ragged_deepseek_prefill=True,
|
mla_prefill_backend=MLAPrefillBackendEnum.TRTLLM_RAGGED,
|
||||||
disable_flashinfer_prefill=True,
|
|
||||||
)
|
)
|
||||||
assert config.mla_prefill_backend == MLAPrefillBackendEnum.TRTLLM_RAGGED
|
assert config.mla_prefill_backend == MLAPrefillBackendEnum.TRTLLM_RAGGED
|
||||||
|
|||||||
@@ -6,12 +6,9 @@ from typing import Any, Literal
|
|||||||
from pydantic import field_validator
|
from pydantic import field_validator
|
||||||
|
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum
|
from vllm.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum
|
||||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
class AttentionConfig:
|
class AttentionConfig:
|
||||||
@@ -36,27 +33,16 @@ class AttentionConfig:
|
|||||||
Fixes the split count so grid dimensions are constant across captures,
|
Fixes the split count so grid dimensions are constant across captures,
|
||||||
and buffers can be pre-allocated to avoid inflating the memory estimate."""
|
and buffers can be pre-allocated to avoid inflating the memory estimate."""
|
||||||
|
|
||||||
use_cudnn_prefill: bool = False
|
|
||||||
"""Deprecated: cuDNN prefill backend has been removed."""
|
|
||||||
|
|
||||||
use_trtllm_ragged_deepseek_prefill: bool = False
|
|
||||||
"""Whether to use TRTLLM ragged deepseek prefill."""
|
|
||||||
|
|
||||||
use_trtllm_attention: bool | None = None
|
use_trtllm_attention: bool | None = None
|
||||||
"""If set to True/False, use or don't use the TRTLLM attention backend
|
"""If set to True/False, use or don't use the TRTLLM attention backend
|
||||||
in flashinfer. If None, auto-detect the attention backend in flashinfer."""
|
in flashinfer. If None, auto-detect the attention backend in flashinfer."""
|
||||||
|
|
||||||
disable_flashinfer_prefill: bool | None = None
|
|
||||||
"""Whether to disable flashinfer prefill."""
|
|
||||||
|
|
||||||
disable_flashinfer_q_quantization: bool = False
|
disable_flashinfer_q_quantization: bool = False
|
||||||
"""If set, when using fp8 kv, do not quantize Q to fp8."""
|
"""If set, when using fp8 kv, do not quantize Q to fp8."""
|
||||||
|
|
||||||
mla_prefill_backend: MLAPrefillBackendEnum | None = None
|
mla_prefill_backend: MLAPrefillBackendEnum | None = None
|
||||||
"""MLA prefill backend to use. If None, will be selected automatically.
|
"""MLA prefill backend to use. If None, will be selected automatically.
|
||||||
Valid options: FLASH_ATTN (FA3/FA4), FLASHINFER, TRTLLM_RAGGED.
|
Valid options: FLASH_ATTN (FA3/FA4), FLASHINFER, TRTLLM_RAGGED."""
|
||||||
This option supersedes use_trtllm_ragged_deepseek_prefill
|
|
||||||
and disable_flashinfer_prefill which are deprecated."""
|
|
||||||
|
|
||||||
use_prefill_query_quantization: bool = False
|
use_prefill_query_quantization: bool = False
|
||||||
"""If set, quantize query for attention in prefill."""
|
"""If set, quantize query for attention in prefill."""
|
||||||
@@ -123,40 +109,3 @@ class AttentionConfig:
|
|||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
return MLAPrefillBackendEnum[value.upper()]
|
return MLAPrefillBackendEnum[value.upper()]
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
|
||||||
self._migrate_deprecated_mla_prefill_flags()
|
|
||||||
|
|
||||||
def _migrate_deprecated_mla_prefill_flags(self) -> None:
|
|
||||||
"""Migrate deprecated MLA prefill flags to mla_prefill_backend."""
|
|
||||||
# If the new option is already set, it takes precedence
|
|
||||||
if self.mla_prefill_backend is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check for deprecated flags and migrate them.
|
|
||||||
# Only the first flag encountered sets the backend.
|
|
||||||
if self.use_cudnn_prefill:
|
|
||||||
raise ValueError(
|
|
||||||
"The cuDNN MLA prefill backend has been removed. "
|
|
||||||
"Use --attention-config.mla_prefill_backend=FLASH_ATTN or "
|
|
||||||
"FLASHINFER or TRTLLM_RAGGED instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.use_trtllm_ragged_deepseek_prefill:
|
|
||||||
if self.mla_prefill_backend is None:
|
|
||||||
self.mla_prefill_backend = MLAPrefillBackendEnum.TRTLLM_RAGGED
|
|
||||||
logger.warning_once(
|
|
||||||
"use_trtllm_ragged_deepseek_prefill is deprecated and "
|
|
||||||
"will be removed in v0.22. Use "
|
|
||||||
"--attention-config.mla_prefill_backend=TRTLLM_RAGGED "
|
|
||||||
"instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.disable_flashinfer_prefill:
|
|
||||||
if self.mla_prefill_backend is None:
|
|
||||||
self.mla_prefill_backend = MLAPrefillBackendEnum.FLASH_ATTN
|
|
||||||
logger.warning_once(
|
|
||||||
"disable_flashinfer_prefill is deprecated and will be removed "
|
|
||||||
"in v0.22. Use --attention-config.mla_prefill_backend="
|
|
||||||
"FLASH_ATTN instead."
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -779,14 +779,14 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
|||||||
attn_out,
|
attn_out,
|
||||||
lse,
|
lse,
|
||||||
get_dcp_group(),
|
get_dcp_group(),
|
||||||
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
|
is_lse_base_on_e=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
attn_out = cp_lse_ag_out_rs(
|
attn_out = cp_lse_ag_out_rs(
|
||||||
attn_out,
|
attn_out,
|
||||||
lse,
|
lse,
|
||||||
get_dcp_group(),
|
get_dcp_group(),
|
||||||
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
|
is_lse_base_on_e=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# v_up projection
|
# v_up projection
|
||||||
|
|||||||
Reference in New Issue
Block a user