mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Attention] Abstract the MLA prefill backends and eliminate cuDNN (#32623)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
@@ -30,7 +30,6 @@ REPO_ROOT = Path(__file__).parent.parent.parent
|
||||
RELEVANT_PATTERNS = [
|
||||
"vllm/v1/attention/backends/*.py",
|
||||
"vllm/v1/attention/backends/**/*.py",
|
||||
"vllm/v1/attention/backends/fa_utils.py",
|
||||
"vllm/model_executor/layers/attention/mla_attention.py",
|
||||
"vllm/platforms/cuda.py",
|
||||
"tools/pre_commit/generate_attention_backend_docs.py",
|
||||
@@ -68,6 +67,11 @@ def is_relevant_file(filepath: str) -> bool:
|
||||
return any(fnmatch.fnmatch(path_str, pattern) for pattern in RELEVANT_PATTERNS)
|
||||
|
||||
|
||||
MLA_PREFILL_DIR = BACKENDS_DIR / "mla" / "prefill"
|
||||
MLA_PREFILL_REGISTRY_FILE = MLA_PREFILL_DIR / "registry.py"
|
||||
MLA_PREFILL_SELECTOR_FILE = MLA_PREFILL_DIR / "selector.py"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AST utility helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -293,6 +297,242 @@ def get_file_from_class_path(class_path: str) -> Path | None:
|
||||
return py_file if py_file.exists() else None
|
||||
|
||||
|
||||
def parse_mla_prefill_registry() -> dict[str, str]:
|
||||
"""Parse MLAPrefillBackendEnum from the prefill registry.
|
||||
|
||||
Returns:
|
||||
A dict mapping backend names to their class paths.
|
||||
"""
|
||||
if not MLA_PREFILL_REGISTRY_FILE.exists():
|
||||
return {}
|
||||
|
||||
try:
|
||||
tree = ast.parse(MLA_PREFILL_REGISTRY_FILE.read_text())
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef) and node.name == "MLAPrefillBackendEnum":
|
||||
return _extract_enum_values(node)
|
||||
return {}
|
||||
|
||||
|
||||
def parse_mla_prefill_priorities() -> dict[str, list[str]]:
|
||||
"""Parse MLA prefill backend priorities from selector.py.
|
||||
|
||||
Returns:
|
||||
A dict with keys like 'blackwell' and 'default' containing
|
||||
lists of backend enum names in priority order.
|
||||
"""
|
||||
if not MLA_PREFILL_SELECTOR_FILE.exists():
|
||||
return {}
|
||||
|
||||
try:
|
||||
tree = ast.parse(MLA_PREFILL_SELECTOR_FILE.read_text())
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
priorities: dict[str, list[str]] = {}
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.FunctionDef):
|
||||
continue
|
||||
if node.name != "_get_mla_prefill_backend_priorities":
|
||||
continue
|
||||
|
||||
# Look for if statements checking device_capability.major
|
||||
for stmt in ast.walk(node):
|
||||
if not isinstance(stmt, ast.If):
|
||||
continue
|
||||
|
||||
# Check if it's a capability.major == 10 check (Blackwell)
|
||||
is_blackwell = (
|
||||
isinstance(stmt.test, ast.Compare)
|
||||
and isinstance(stmt.test.left, ast.Attribute)
|
||||
and stmt.test.left.attr == "major"
|
||||
and stmt.test.comparators
|
||||
and isinstance(stmt.test.comparators[0], ast.Constant)
|
||||
and stmt.test.comparators[0].value == 10
|
||||
)
|
||||
|
||||
# Extract backends from return statements
|
||||
for body_stmt in stmt.body:
|
||||
if isinstance(body_stmt, ast.Return) and isinstance(
|
||||
body_stmt.value, ast.List
|
||||
):
|
||||
backends = []
|
||||
for elt in body_stmt.value.elts:
|
||||
if isinstance(elt, ast.Attribute):
|
||||
backends.append(elt.attr)
|
||||
if is_blackwell:
|
||||
priorities["blackwell"] = backends
|
||||
else:
|
||||
priorities["default"] = backends
|
||||
|
||||
# Extract from else branch
|
||||
for else_stmt in stmt.orelse:
|
||||
if isinstance(else_stmt, ast.Return) and isinstance(
|
||||
else_stmt.value, ast.List
|
||||
):
|
||||
backends = []
|
||||
for elt in else_stmt.value.elts:
|
||||
if isinstance(elt, ast.Attribute):
|
||||
backends.append(elt.attr)
|
||||
priorities["default"] = backends
|
||||
|
||||
return priorities
|
||||
|
||||
|
||||
def parse_mla_prefill_backend_file(class_path: str) -> dict[str, Any] | None:
|
||||
"""Parse a single MLA prefill backend file to extract its properties.
|
||||
|
||||
Args:
|
||||
class_path: The fully qualified class path.
|
||||
|
||||
Returns:
|
||||
A dict with backend properties, or None if parsing fails.
|
||||
"""
|
||||
file_path = get_file_from_class_path(class_path)
|
||||
if file_path is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
tree = ast.parse(file_path.read_text())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
class_name = class_path.rsplit(".", 1)[1]
|
||||
class_node = find_class_in_ast(tree, class_name)
|
||||
if class_node is None:
|
||||
return None
|
||||
|
||||
info: dict[str, Any] = {
|
||||
"compute_capability": "Any",
|
||||
"requires_r1_dims": False,
|
||||
"dtypes": "fp16, bf16", # Default from base class
|
||||
}
|
||||
|
||||
# Parse class variables
|
||||
for item in class_node.body:
|
||||
if isinstance(item, ast.Assign):
|
||||
for target in item.targets:
|
||||
if (
|
||||
isinstance(target, ast.Name)
|
||||
and target.id == "requires_r1_mla_dimensions"
|
||||
and isinstance(item.value, ast.Constant)
|
||||
):
|
||||
info["requires_r1_dims"] = item.value.value
|
||||
|
||||
# Parse supported_dtypes class variable
|
||||
if (
|
||||
isinstance(item, ast.AnnAssign)
|
||||
and isinstance(item.target, ast.Name)
|
||||
and item.target.id == "supported_dtypes"
|
||||
and isinstance(item.value, ast.List)
|
||||
):
|
||||
dtype_map = {"float16": "fp16", "bfloat16": "bf16", "float32": "fp32"}
|
||||
dtypes = []
|
||||
for elt in item.value.elts:
|
||||
if isinstance(elt, ast.Attribute):
|
||||
dtypes.append(dtype_map.get(elt.attr, elt.attr))
|
||||
if dtypes:
|
||||
info["dtypes"] = ", ".join(dtypes)
|
||||
|
||||
# Parse get_name static method
|
||||
get_name_method = find_method(class_node, "get_name")
|
||||
if get_name_method:
|
||||
for n in ast.walk(get_name_method):
|
||||
if isinstance(n, ast.Return) and isinstance(n.value, ast.Constant):
|
||||
info["name"] = n.value.value
|
||||
|
||||
# Parse supports_compute_capability classmethod
|
||||
cc_method = find_method(class_node, "supports_compute_capability")
|
||||
if cc_method:
|
||||
for n in ast.walk(cc_method):
|
||||
# Look for capability.major == 10 style checks
|
||||
if (
|
||||
isinstance(n, ast.Compare)
|
||||
and isinstance(n.left, ast.Attribute)
|
||||
and n.left.attr == "major"
|
||||
and n.comparators
|
||||
and isinstance(n.comparators[0], ast.Constant)
|
||||
):
|
||||
major = n.comparators[0].value
|
||||
info["compute_capability"] = f"{major}.x"
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def parse_mla_prefill_backends() -> list[dict[str, Any]]:
|
||||
"""Parse MLA prefill backend options from the prefill registry.
|
||||
|
||||
MLA uses different backends for prefill vs decode. The decode backends are
|
||||
registered in the main registry, but prefill backends have their own
|
||||
registry at vllm/v1/attention/backends/mla/prefill/registry.py.
|
||||
|
||||
Returns a list of prefill backend info dicts with their requirements.
|
||||
"""
|
||||
registry = parse_mla_prefill_registry()
|
||||
priorities = parse_mla_prefill_priorities()
|
||||
|
||||
if not registry:
|
||||
return []
|
||||
|
||||
# Get the priority order (Blackwell order shows all backends)
|
||||
priority_order = priorities.get("blackwell", list(registry.keys()))
|
||||
|
||||
prefill_backends: list[dict[str, Any]] = []
|
||||
|
||||
# Backend-specific metadata that can't be easily parsed from code
|
||||
backend_metadata = {
|
||||
"TRTLLM_RAGGED": {
|
||||
"description": "TensorRT-LLM ragged attention",
|
||||
},
|
||||
"FLASHINFER": {
|
||||
"description": "FlashInfer CUTLASS backend",
|
||||
},
|
||||
"FLASH_ATTN": {
|
||||
"description": "FlashAttention varlen (FA2/FA3/FA4)",
|
||||
},
|
||||
}
|
||||
|
||||
for backend_name in priority_order:
|
||||
if backend_name not in registry:
|
||||
continue
|
||||
|
||||
class_path = registry[backend_name]
|
||||
backend_info = parse_mla_prefill_backend_file(class_path)
|
||||
if backend_info is None:
|
||||
continue
|
||||
|
||||
metadata = backend_metadata.get(backend_name, {})
|
||||
display_name = backend_info.get("name", backend_name)
|
||||
|
||||
# Add marker for default Blackwell backend
|
||||
marker = ""
|
||||
if backend_name == priority_order[0] and priorities.get("blackwell"):
|
||||
marker = "‡"
|
||||
|
||||
notes = ""
|
||||
if backend_info.get("requires_r1_dims"):
|
||||
notes = "DeepSeek R1 dims only"
|
||||
elif backend_name == "FLASH_ATTN":
|
||||
notes = "FA4 on SM100+, FA3 on SM90, FA2 otherwise"
|
||||
|
||||
prefill_backends.append(
|
||||
{
|
||||
"name": display_name,
|
||||
"marker": marker,
|
||||
"description": metadata.get("description", ""),
|
||||
"dtypes": backend_info.get("dtypes", "fp16, bf16"),
|
||||
"compute_capability": backend_info.get("compute_capability", "Any"),
|
||||
"notes": notes,
|
||||
}
|
||||
)
|
||||
|
||||
return prefill_backends
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend feature extraction from AST
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -807,86 +1047,6 @@ def parse_flashinfer_trtllm_features() -> dict[str, dict[str, Any]]:
|
||||
}
|
||||
|
||||
|
||||
def parse_mla_prefill_backends() -> list[dict[str, Any]]:
|
||||
"""Parse MLA prefill backend options from mla_attention.py.
|
||||
|
||||
MLA uses different backends for prefill vs decode. The decode backends are
|
||||
registered in the registry, but prefill backends are selected at runtime
|
||||
based on conditions in MLACommonImpl.__init__.
|
||||
|
||||
Returns a list of prefill backend info dicts with their requirements.
|
||||
"""
|
||||
if not MLA_ATTENTION_FILE.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
tree = ast.parse(MLA_ATTENTION_FILE.read_text())
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
# Find compute capability requirements by parsing use_* functions
|
||||
trtllm_cc = _find_cc_in_function(tree, "use_trtllm_ragged_deepseek_prefill")
|
||||
flashinfer_cc = _find_cc_in_function(tree, "use_flashinfer_prefill")
|
||||
cudnn_cc = _find_cc_in_function(tree, "use_cudnn_prefill")
|
||||
|
||||
# Build prefill backend list based on what we found
|
||||
# Order matches the priority in MLACommonImpl.__init__
|
||||
prefill_backends: list[dict[str, Any]] = []
|
||||
|
||||
# TRT-LLM Ragged (highest priority if available)
|
||||
if trtllm_cc:
|
||||
prefill_backends.append(
|
||||
{
|
||||
"name": "TRT-LLM Ragged‡",
|
||||
"description": "TensorRT-LLM ragged attention",
|
||||
"compute_capability": trtllm_cc,
|
||||
"enable": "Default on SM100",
|
||||
"disable": "`-ac.use_trtllm_ragged_deepseek_prefill=0`",
|
||||
"notes": "DeepSeek R1 dims only",
|
||||
}
|
||||
)
|
||||
|
||||
# FlashInfer prefill
|
||||
if flashinfer_cc:
|
||||
prefill_backends.append(
|
||||
{
|
||||
"name": "FlashInfer",
|
||||
"description": "FlashInfer CUTLASS backend",
|
||||
"compute_capability": flashinfer_cc,
|
||||
"enable": "`-ac.disable_flashinfer_prefill=0`",
|
||||
"disable": "`-ac.disable_flashinfer_prefill=1`",
|
||||
"notes": "DeepSeek R1 dims only",
|
||||
}
|
||||
)
|
||||
|
||||
# cuDNN prefill
|
||||
if cudnn_cc:
|
||||
prefill_backends.append(
|
||||
{
|
||||
"name": "cuDNN",
|
||||
"description": "cuDNN-based attention",
|
||||
"compute_capability": cudnn_cc,
|
||||
"enable": "`-ac.use_cudnn_prefill=1`",
|
||||
"disable": "`-ac.use_cudnn_prefill=0`",
|
||||
"notes": "",
|
||||
}
|
||||
)
|
||||
|
||||
# FlashAttention is always available as fallback
|
||||
prefill_backends.append(
|
||||
{
|
||||
"name": "FlashAttention",
|
||||
"description": "FlashAttention varlen (FA2/FA3)",
|
||||
"compute_capability": "Any",
|
||||
"enable": "Default fallback",
|
||||
"disable": "Use other backends",
|
||||
"notes": "FA3 on SM90, FA2 otherwise",
|
||||
}
|
||||
)
|
||||
|
||||
return prefill_backends
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend variant expansion (FA2/FA3/FA4, FlashInfer native/TRTLLM)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1415,20 +1575,22 @@ def generate_mla_section(
|
||||
"",
|
||||
"### Prefill Backends",
|
||||
"",
|
||||
"The prefill backend is selected at runtime based on hardware and",
|
||||
"configuration.",
|
||||
"To explicitly select a prefill backend, use",
|
||||
"`-ac.mla_prefill_backend=<BACKEND>` (e.g., `FLASH_ATTN`, `FLASHINFER`).",
|
||||
"Otherwise, the prefill backend is selected automatically at runtime based on",
|
||||
"hardware and configuration.",
|
||||
"",
|
||||
"| Backend | Description | Compute Cap. | Enable | Disable | Notes |",
|
||||
"| ------- | ----------- | ------------ | ------ | ------- | ----- |",
|
||||
"| Backend | Description | Dtypes | Compute Cap. | Notes |",
|
||||
"| ------- | ----------- | ------ | ------------ | ----- |",
|
||||
]
|
||||
|
||||
for backend in prefill_backends:
|
||||
row = "| {} | {} | {} | {} | {} | {} |".format(
|
||||
row = "| `{}`{} | {} | {} | {} | {} |".format(
|
||||
backend["name"],
|
||||
backend.get("marker", ""),
|
||||
backend["description"],
|
||||
backend.get("dtypes", "fp16, bf16"),
|
||||
backend["compute_capability"],
|
||||
backend["enable"],
|
||||
backend["disable"],
|
||||
backend.get("notes", ""),
|
||||
)
|
||||
lines.append(row.replace(" ", " "))
|
||||
@@ -1441,6 +1603,9 @@ def generate_mla_section(
|
||||
"",
|
||||
"### Decode Backends",
|
||||
"",
|
||||
"MLA decode backends are selected using the standard",
|
||||
"`-ac.backend=<BACKEND>` argument (e.g., `FLASHMLA`, `TRITON_MLA`).",
|
||||
"",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user