[Attention] Abstract the MLA prefill backends and eliminate cuDNN (#32623)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni
2026-05-01 13:36:20 -04:00
committed by GitHub
parent 51295793a2
commit f3fef12350
16 changed files with 1629 additions and 708 deletions
@@ -30,7 +30,6 @@ REPO_ROOT = Path(__file__).parent.parent.parent
RELEVANT_PATTERNS = [
"vllm/v1/attention/backends/*.py",
"vllm/v1/attention/backends/**/*.py",
"vllm/v1/attention/backends/fa_utils.py",
"vllm/model_executor/layers/attention/mla_attention.py",
"vllm/platforms/cuda.py",
"tools/pre_commit/generate_attention_backend_docs.py",
@@ -68,6 +67,11 @@ def is_relevant_file(filepath: str) -> bool:
return any(fnmatch.fnmatch(path_str, pattern) for pattern in RELEVANT_PATTERNS)
# Directory holding the MLA prefill backend modules, located under BACKENDS_DIR.
MLA_PREFILL_DIR = BACKENDS_DIR / "mla" / "prefill"
# Registry module; parse_mla_prefill_registry() reads MLAPrefillBackendEnum from it.
MLA_PREFILL_REGISTRY_FILE = MLA_PREFILL_DIR / "registry.py"
# Selector module; parse_mla_prefill_priorities() reads the priority lists from it.
MLA_PREFILL_SELECTOR_FILE = MLA_PREFILL_DIR / "selector.py"
# ---------------------------------------------------------------------------
# AST utility helpers
# ---------------------------------------------------------------------------
@@ -293,6 +297,242 @@ def get_file_from_class_path(class_path: str) -> Path | None:
return py_file if py_file.exists() else None
def parse_mla_prefill_registry() -> dict[str, str]:
    """Read the ``MLAPrefillBackendEnum`` members out of the prefill registry.

    The registry module is inspected through its AST; it is never imported.

    Returns:
        A dict mapping backend names to their class paths, as produced by
        ``_extract_enum_values`` for the enum class. The dict is empty when
        the registry file is missing, cannot be read or parsed, or does not
        define ``MLAPrefillBackendEnum``.
    """
    registry_path = MLA_PREFILL_REGISTRY_FILE
    if not registry_path.exists():
        return {}

    try:
        module = ast.parse(registry_path.read_text())
    except Exception:
        # Best effort: an unreadable or malformed registry yields no entries.
        return {}

    # ast.walk also visits nested scopes, so the enum class does not have to
    # be a top-level statement of the module.
    enum_class = next(
        (
            candidate
            for candidate in ast.walk(module)
            if isinstance(candidate, ast.ClassDef)
            and candidate.name == "MLAPrefillBackendEnum"
        ),
        None,
    )
    if enum_class is None:
        return {}
    return _extract_enum_values(enum_class)
def _mla_prefill_returned_backends(stmt: ast.stmt) -> list[str] | None:
    """Extract enum member names from a ``return [Enum.A, Enum.B]`` statement.

    Returns:
        The ``attr`` of every attribute-access element of the returned list
        literal, in source order, or None when *stmt* is not a ``return`` of
        a list literal. Elements that are not attribute accesses are skipped.
    """
    if not (isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.List)):
        return None
    return [elt.attr for elt in stmt.value.elts if isinstance(elt, ast.Attribute)]


def _is_mla_prefill_blackwell_test(test: ast.expr) -> bool:
    """Report whether *test* compares a ``.major`` attribute with the constant 10.

    Only the left operand and the first comparator are examined; the
    comparison operator itself is not.
    """
    return (
        isinstance(test, ast.Compare)
        and isinstance(test.left, ast.Attribute)
        and test.left.attr == "major"
        and bool(test.comparators)
        and isinstance(test.comparators[0], ast.Constant)
        and test.comparators[0].value == 10
    )


def parse_mla_prefill_priorities() -> dict[str, list[str]]:
    """Parse MLA prefill backend priorities from selector.py.

    The function ``_get_mla_prefill_backend_priorities`` in the selector
    module is located through the AST, and the list literals it returns are
    collected:

    * a list returned from the body of an ``if <x>.major <op> 10`` statement
      is stored under ``"blackwell"``;
    * a list returned from the body of any other ``if`` statement, or from
      an ``else`` branch, is stored under ``"default"``;
    * when no ``"default"`` was found that way, a list returned directly
      from the function body (early-return style without ``else``) is used
      as ``"default"``.

    Later matches overwrite earlier ones for the same key.

    Returns:
        A dict with keys like 'blackwell' and 'default' containing lists of
        backend enum names in priority order. Empty when the selector file
        is missing, cannot be read or parsed, or contains no matching
        function.
    """
    if not MLA_PREFILL_SELECTOR_FILE.exists():
        return {}

    try:
        tree = ast.parse(MLA_PREFILL_SELECTOR_FILE.read_text())
    except Exception:
        # Best effort: an unreadable or malformed selector yields no priorities.
        return {}

    priorities: dict[str, list[str]] = {}

    for node in ast.walk(tree):
        if not isinstance(node, ast.FunctionDef):
            continue
        if node.name != "_get_mla_prefill_backend_priorities":
            continue

        for stmt in ast.walk(node):
            if not isinstance(stmt, ast.If):
                continue

            # The body of a `.major == 10` check is the Blackwell list; the
            # body of any other `if` is treated as the default list.
            body_key = (
                "blackwell" if _is_mla_prefill_blackwell_test(stmt.test) else "default"
            )
            for body_stmt in stmt.body:
                backends = _mla_prefill_returned_backends(body_stmt)
                if backends is not None:
                    priorities[body_key] = backends

            # An `elif` shows up here as a nested ast.If, which has no direct
            # return; it is handled when ast.walk reaches it on its own.
            for else_stmt in stmt.orelse:
                backends = _mla_prefill_returned_backends(else_stmt)
                if backends is not None:
                    priorities["default"] = backends

        # Early-return style: `if ...: return [...]` followed by a plain
        # `return [...]` has no `else` branch, so pick up the first list
        # returned directly from the function body as the default.
        if "default" not in priorities:
            for stmt in node.body:
                backends = _mla_prefill_returned_backends(stmt)
                if backends is not None:
                    priorities["default"] = backends
                    break

    return priorities
def _mla_prefill_class_var_values(
    class_node: ast.ClassDef,
) -> list[tuple[str, ast.expr]]:
    """List ``(name, value)`` pairs for simple class-level assignments.

    Both plain (``x = v``) and annotated (``x: T = v``) assignments to bare
    names are covered, in source order. Annotated declarations without a
    value are skipped.
    """
    pairs: list[tuple[str, ast.expr]] = []
    for item in class_node.body:
        if isinstance(item, ast.Assign):
            pairs.extend(
                (target.id, item.value)
                for target in item.targets
                if isinstance(target, ast.Name)
            )
        elif (
            isinstance(item, ast.AnnAssign)
            and isinstance(item.target, ast.Name)
            and item.value is not None
        ):
            pairs.append((item.target.id, item.value))
    return pairs


def parse_mla_prefill_backend_file(class_path: str) -> dict[str, Any] | None:
    """Parse a single MLA prefill backend file to extract its properties.

    The backend module is inspected through its AST; it is never imported.

    Args:
        class_path: The fully qualified class path.

    Returns:
        A dict with backend properties, or None if the file cannot be
        located, read or parsed, or does not contain the class. Keys:

        * ``"compute_capability"``: ``"<major>.x"`` taken from a comparison
          against ``.major`` inside ``supports_compute_capability``, else
          ``"Any"``.
        * ``"requires_r1_dims"``: constant assigned to the class variable
          ``requires_r1_mla_dimensions``, else ``False``.
        * ``"dtypes"``: comma-separated short names from the class variable
          ``supported_dtypes``, else ``"fp16, bf16"``.
        * ``"name"``: constant returned by ``get_name``; absent when no such
          return is found.
    """
    file_path = get_file_from_class_path(class_path)
    if file_path is None:
        return None

    try:
        tree = ast.parse(file_path.read_text())
    except Exception:
        # Best effort: an unreadable or malformed backend file is skipped.
        return None

    # rpartition never raises: a path without a dot is used as the class name.
    class_name = class_path.rpartition(".")[2]
    class_node = find_class_in_ast(tree, class_name)
    if class_node is None:
        return None

    info: dict[str, Any] = {
        "compute_capability": "Any",
        "requires_r1_dims": False,
        "dtypes": "fp16, bf16",  # Default from base class
    }

    # Short display names for dtype attributes; unknown names pass through.
    dtype_map = {"float16": "fp16", "bfloat16": "bf16", "float32": "fp32"}

    # Class variables are accepted both as plain and as annotated assignments
    # so that e.g. `requires_r1_mla_dimensions: bool = True` is not missed.
    for var_name, value in _mla_prefill_class_var_values(class_node):
        if var_name == "requires_r1_mla_dimensions" and isinstance(
            value, ast.Constant
        ):
            info["requires_r1_dims"] = value.value
        elif var_name == "supported_dtypes" and isinstance(value, ast.List):
            dtypes = [
                dtype_map.get(elt.attr, elt.attr)
                for elt in value.elts
                if isinstance(elt, ast.Attribute)
            ]
            # An empty or unparseable list keeps the base-class default.
            if dtypes:
                info["dtypes"] = ", ".join(dtypes)

    # Parse get_name static method; the last constant return found wins.
    get_name_method = find_method(class_node, "get_name")
    if get_name_method:
        for n in ast.walk(get_name_method):
            if isinstance(n, ast.Return) and isinstance(n.value, ast.Constant):
                info["name"] = n.value.value

    # Parse supports_compute_capability classmethod
    cc_method = find_method(class_node, "supports_compute_capability")
    if cc_method:
        for n in ast.walk(cc_method):
            # Look for capability.major == 10 style checks. The comparison
            # operator is not examined, so `>=` renders the same as `==`,
            # and the last matching comparison wins.
            if (
                isinstance(n, ast.Compare)
                and isinstance(n.left, ast.Attribute)
                and n.left.attr == "major"
                and n.comparators
                and isinstance(n.comparators[0], ast.Constant)
            ):
                major = n.comparators[0].value
                info["compute_capability"] = f"{major}.x"

    return info
def parse_mla_prefill_backends() -> list[dict[str, Any]]:
    """Parse MLA prefill backend options from the prefill registry.

    MLA uses different backends for prefill vs decode. The decode backends are
    registered in the main registry, but prefill backends have their own
    registry at vllm/v1/attention/backends/mla/prefill/registry.py.

    Backends are listed in the Blackwell priority order reported by
    ``parse_mla_prefill_priorities``; names in that order that are missing
    from the registry, and backends whose file cannot be parsed, are skipped.

    Returns:
        A list of prefill backend info dicts, each with the keys ``"name"``,
        ``"marker"``, ``"description"``, ``"dtypes"``,
        ``"compute_capability"`` and ``"notes"``. Empty when the registry
        yields no entries.
    """
    registry = parse_mla_prefill_registry()
    if not registry:
        return []

    # Get the priority order (Blackwell order shows all backends). Fall back
    # to registry order when the Blackwell list is missing *or* empty, so an
    # unparseable selector does not produce an empty table.
    priority_order = parse_mla_prefill_priorities().get("blackwell") or list(registry)

    # Backend-specific descriptions that can't be easily parsed from code
    descriptions = {
        "TRTLLM_RAGGED": "TensorRT-LLM ragged attention",
        "FLASHINFER": "FlashInfer CUTLASS backend",
        "FLASH_ATTN": "FlashAttention varlen (FA2/FA3/FA4)",
    }

    prefill_backends: list[dict[str, Any]] = []
    for backend_name in priority_order:
        class_path = registry.get(backend_name)
        if class_path is None:
            continue

        backend_info = parse_mla_prefill_backend_file(class_path)
        if backend_info is None:
            continue

        # Every row currently gets an empty marker.
        # NOTE(review): the first backend in the Blackwell order presumably
        # once carried a footnote symbol here - confirm whether one is wanted.
        marker = ""

        if backend_info.get("requires_r1_dims"):
            notes = "DeepSeek R1 dims only"
        elif backend_name == "FLASH_ATTN":
            notes = "FA4 on SM100+, FA3 on SM90, FA2 otherwise"
        else:
            notes = ""

        prefill_backends.append(
            {
                # Prefer the name the backend reports via get_name().
                "name": backend_info.get("name", backend_name),
                "marker": marker,
                "description": descriptions.get(backend_name, ""),
                "dtypes": backend_info.get("dtypes", "fp16, bf16"),
                "compute_capability": backend_info.get("compute_capability", "Any"),
                "notes": notes,
            }
        )

    return prefill_backends
# ---------------------------------------------------------------------------
# Backend feature extraction from AST
# ---------------------------------------------------------------------------
@@ -807,86 +1047,6 @@ def parse_flashinfer_trtllm_features() -> dict[str, dict[str, Any]]:
}
def parse_mla_prefill_backends() -> list[dict[str, Any]]:
    """Build the list of MLA prefill backend rows from mla_attention.py.

    MLA picks its prefill backend separately from its decode backend. The
    ``use_*`` helper functions in mla_attention.py are looked up with
    ``_find_cc_in_function``; a backend is listed only when that lookup
    returns a truthy compute capability. FlashAttention is always appended
    last as the fallback.

    Returns:
        A list of dicts with the keys ``"name"``, ``"description"``,
        ``"compute_capability"``, ``"enable"``, ``"disable"`` and
        ``"notes"``, in priority order. Empty when mla_attention.py is
        missing or cannot be read or parsed.
    """
    if not MLA_ATTENTION_FILE.exists():
        return []

    try:
        module = ast.parse(MLA_ATTENTION_FILE.read_text())
    except Exception:
        return []

    # (compute capability, name, description, enable, disable, notes),
    # ordered to match the priority in MLACommonImpl.__init__.
    optional_backends = (
        (
            _find_cc_in_function(module, "use_trtllm_ragged_deepseek_prefill"),
            "TRT-LLM Ragged‡",
            "TensorRT-LLM ragged attention",
            "Default on SM100",
            "`-ac.use_trtllm_ragged_deepseek_prefill=0`",
            "DeepSeek R1 dims only",
        ),
        (
            _find_cc_in_function(module, "use_flashinfer_prefill"),
            "FlashInfer",
            "FlashInfer CUTLASS backend",
            "`-ac.disable_flashinfer_prefill=0`",
            "`-ac.disable_flashinfer_prefill=1`",
            "DeepSeek R1 dims only",
        ),
        (
            _find_cc_in_function(module, "use_cudnn_prefill"),
            "cuDNN",
            "cuDNN-based attention",
            "`-ac.use_cudnn_prefill=1`",
            "`-ac.use_cudnn_prefill=0`",
            "",
        ),
    )

    rows: list[dict[str, Any]] = [
        {
            "name": name,
            "description": description,
            "compute_capability": compute_capability,
            "enable": enable,
            "disable": disable,
            "notes": notes,
        }
        for compute_capability, name, description, enable, disable, notes in (
            optional_backends
        )
        if compute_capability
    ]

    # FlashAttention needs no capability lookup: it is the unconditional fallback.
    rows.append(
        {
            "name": "FlashAttention",
            "description": "FlashAttention varlen (FA2/FA3)",
            "compute_capability": "Any",
            "enable": "Default fallback",
            "disable": "Use other backends",
            "notes": "FA3 on SM90, FA2 otherwise",
        }
    )
    return rows
# ---------------------------------------------------------------------------
# Backend variant expansion (FA2/FA3/FA4, FlashInfer native/TRTLLM)
# ---------------------------------------------------------------------------
@@ -1415,20 +1575,22 @@ def generate_mla_section(
"",
"### Prefill Backends",
"",
"The prefill backend is selected at runtime based on hardware and",
"configuration.",
"To explicitly select a prefill backend, use",
"`-ac.mla_prefill_backend=<BACKEND>` (e.g., `FLASH_ATTN`, `FLASHINFER`).",
"Otherwise, the prefill backend is selected automatically at runtime based on",
"hardware and configuration.",
"",
"| Backend | Description | Compute Cap. | Enable | Disable | Notes |",
"| ------- | ----------- | ------------ | ------ | ------- | ----- |",
"| Backend | Description | Dtypes | Compute Cap. | Notes |",
"| ------- | ----------- | ------ | ------------ | ----- |",
]
for backend in prefill_backends:
row = "| {} | {} | {} | {} | {} | {} |".format(
row = "| `{}`{} | {} | {} | {} | {} |".format(
backend["name"],
backend.get("marker", ""),
backend["description"],
backend.get("dtypes", "fp16, bf16"),
backend["compute_capability"],
backend["enable"],
backend["disable"],
backend.get("notes", ""),
)
lines.append(row.replace(" ", " "))
@@ -1441,6 +1603,9 @@ def generate_mla_section(
"",
"### Decode Backends",
"",
"MLA decode backends are selected using the standard",
"`-ac.backend=<BACKEND>` argument (e.g., `FLASHMLA`, `TRITON_MLA`).",
"",
]
)