[MLA Attention Backend] Add TOKENSPEED_MLA backend for DSR1/Kimi K25 prefill + decode on Blackwell (#41778)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Yongye Zhu
2026-05-14 02:48:02 -04:00
committed by GitHub
parent fd7d858c8a
commit 0d2732dd91
14 changed files with 639 additions and 88 deletions
@@ -53,6 +53,7 @@ backends:
- FLASHINFER_MLA
- FLASH_ATTN_MLA # Hopper only
- FLASHMLA # Hopper only
- TOKENSPEED_MLA # Blackwell + R1 dims + FP8 KV (use --kv-cache-dtype fp8)
device: "cuda:0"
repeats: 100
@@ -3,6 +3,7 @@
# Compares all available MLA prefill backends:
# FA backends: fa2, fa3, fa4 (FlashAttention versions)
# Non-FA: flashinfer, cudnn, trtllm (Blackwell-only, require flashinfer)
# CuTe DSL: tokenspeed (Blackwell + R1 dims, requires tokenspeed_mla)
#
# Uses cutlass_mla as the decode backend for impl construction
# (only the prefill path is exercised).
@@ -120,6 +121,7 @@ prefill_backends:
- flashinfer
- cudnn
- trtllm
- tokenspeed
device: "cuda:0"
repeats: 20
+67 -63
View File
@@ -179,19 +179,27 @@ def create_minimal_vllm_config(
if prefill_backend is not None:
prefill_cfg = get_prefill_backend_config(prefill_backend)
if prefill_cfg["flash_attn_version"] is not None:
vllm_config.attention_config.flash_attn_version = prefill_cfg[
"flash_attn_version"
if prefill_cfg.get("mla_prefill_backend_enum") is not None:
# Registry-based backends bypass the deprecated boolean flags.
from vllm.v1.attention.backends.mla.prefill import MLAPrefillBackendEnum
vllm_config.attention_config.mla_prefill_backend = MLAPrefillBackendEnum[
prefill_cfg["mla_prefill_backend_enum"]
]
vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[
"disable_flashinfer_prefill"
]
vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[
"use_cudnn_prefill"
]
vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = prefill_cfg[
"use_trtllm_ragged_deepseek_prefill"
]
else:
if prefill_cfg["flash_attn_version"] is not None:
vllm_config.attention_config.flash_attn_version = prefill_cfg[
"flash_attn_version"
]
vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[
"disable_flashinfer_prefill"
]
vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[
"use_cudnn_prefill"
]
vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = (
prefill_cfg["use_trtllm_ragged_deepseek_prefill"]
)
return vllm_config
@@ -223,22 +231,17 @@ _PREFILL_BACKEND_CONFIG: dict[str, dict] = {
"use_trtllm_ragged_deepseek_prefill": False,
},
"flashinfer": {
"flash_attn_version": None,
"disable_flashinfer_prefill": False,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
"mla_prefill_backend_enum": "FLASHINFER",
},
"cudnn": {
"flash_attn_version": None,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": True,
"use_trtllm_ragged_deepseek_prefill": False,
# cuDNN prefill backend was removed; AttentionConfig raises on use.
"mla_prefill_backend_enum": "FLASHINFER",
},
"trtllm": {
"flash_attn_version": None,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": True,
"mla_prefill_backend_enum": "TRTLLM_RAGGED",
},
"tokenspeed": {
"mla_prefill_backend_enum": "TOKENSPEED_MLA",
},
}
@@ -625,6 +628,21 @@ def _create_backend_impl(
# Create mock layer
layer = MockLayer(device, impl=impl, kv_cache_spec=kv_cache_spec)
# Attach a prefill backend (MLAAttention does this in __init__; the metadata
# builder reads layer.prefill_backend from static_forward_context).
from vllm.v1.attention.backends.mla.prefill import get_mla_prefill_backend
prefill_backend_cls = get_mla_prefill_backend(vllm_config)
layer.prefill_backend = prefill_backend_cls(
num_heads=mla_dims["num_q_heads"],
scale=(mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"]) ** -0.5,
kv_lora_rank=mla_dims["kv_lora_rank"],
qk_nope_head_dim=mla_dims["qk_nope_head_dim"],
qk_rope_head_dim=mla_dims["qk_rope_head_dim"],
v_head_dim=mla_dims["v_head_dim"],
vllm_config=vllm_config,
)
# Create builder instance if needed
builder_instance = None
if builder_class:
@@ -961,19 +979,6 @@ def _run_mla_benchmark_batched(
results = []
with set_current_vllm_config(vllm_config):
# Clear cached prefill backend detection functions so they re-evaluate
# with the current VllmConfig. These are @functools.cache decorated and
# would otherwise return stale results from a previous backend's config.
from vllm.model_executor.layers.attention.mla_attention import (
use_cudnn_prefill,
use_flashinfer_prefill,
use_trtllm_ragged_deepseek_prefill,
)
use_flashinfer_prefill.cache_clear()
use_cudnn_prefill.cache_clear()
use_trtllm_ragged_deepseek_prefill.cache_clear()
# Create backend impl, layer, builder, and indexer (reused across benchmarks)
impl, layer, builder_instance, indexer = _create_backend_impl(
backend_cfg,
@@ -985,36 +990,35 @@ def _run_mla_benchmark_batched(
kv_cache_dtype=kv_cache_dtype,
)
# Verify the actual prefill backend matches what was requested
# Verify the actual prefill backend matches what was requested. The
# selector + impl construction already raise on misuse; here we just
# check the resolved class against the requested name as a sanity guard.
if prefill_backend is not None:
prefill_cfg = get_prefill_backend_config(prefill_backend)
fa_version = prefill_cfg["flash_attn_version"]
if fa_version is not None:
# FA backend: verify the impl's FA version
actual_fa_version = getattr(impl, "vllm_flash_attn_version", None)
expected_class = {
"fa2": "FlashAttnPrefillBackend",
"fa3": "FlashAttnPrefillBackend",
"fa4": "FlashAttnPrefillBackend",
"flashinfer": "FlashInferPrefillBackend",
"trtllm": "TrtllmRaggedPrefillBackend",
"tokenspeed": "TokenspeedMLAPrefillBackend",
}.get(prefill_backend)
actual_class = type(getattr(layer, "prefill_backend", None)).__name__
if expected_class and actual_class != expected_class:
raise RuntimeError(
f"Prefill backend '{prefill_backend}' requested "
f"{expected_class}, got {actual_class}. Check "
f"attention_config plumbing or installed deps."
)
if prefill_backend in {"fa2", "fa3", "fa4"}:
fa_version = int(prefill_backend[2:])
actual_fa_version = getattr(
layer.prefill_backend, "vllm_flash_attn_version", None
)
if actual_fa_version != fa_version:
raise RuntimeError(
f"Prefill backend '{prefill_backend}' requested FA "
f"version {fa_version}, but the impl is using FA "
f"version {actual_fa_version}. Check "
f"vllm/v1/attention/backends/fa_utils.py."
)
else:
# Non-FA backend: verify the builder picked the right path
expected_flags = {
"flashinfer": "_use_fi_prefill",
"cudnn": "_use_cudnn_prefill",
"trtllm": "_use_trtllm_ragged_prefill",
}
flag_name = expected_flags.get(prefill_backend)
if flag_name and not getattr(builder_instance, flag_name, False):
raise RuntimeError(
f"Prefill backend '{prefill_backend}' was requested "
f"but the metadata builder did not enable it. This "
f"usually means a dependency is missing (e.g., "
f"flashinfer not installed) or the platform doesn't "
f"support it."
f"version {fa_version}, got "
f"{actual_fa_version} on {actual_class}."
)
# Run each benchmark with the shared impl
+9 -6
View File
@@ -125,12 +125,13 @@ Priority is **1 = highest** (tried first).
| Priority | Backend |
| -------- | ------- |
| 1 | `FLASHINFER_MLA` |
| 2 | `CUTLASS_MLA` |
| 3 | `FLASH_ATTN_MLA` |
| 4 | `FLASHMLA` |
| 5 | `TRITON_MLA` |
| 6 | `FLASHINFER_MLA_SPARSE`**\*** |
| 7 | `FLASHMLA_SPARSE` |
| 2 | `TOKENSPEED_MLA` |
| 3 | `CUTLASS_MLA` |
| 4 | `FLASH_ATTN_MLA` |
| 5 | `FLASHMLA` |
| 6 | `TRITON_MLA` |
| 7 | `FLASHINFER_MLA_SPARSE`**\*** |
| 8 | `FLASHMLA_SPARSE` |
**Ampere/Hopper (SM 8.x-9.x):**
@@ -202,6 +203,7 @@ hardware and configuration.
| `FLASH_ATTN`‡ | FlashAttention varlen (FA2/FA3/FA4) | fp16, bf16 | Any | FA4 on SM100+, FA3 on SM90, FA2 otherwise |
| `TRTLLM_RAGGED` | TensorRT-LLM ragged attention | fp16, bf16 | 10.x | DeepSeek R1 dims only |
| `FLASHINFER` | FlashInfer CUTLASS backend | fp16, bf16 | 10.x | DeepSeek R1 dims only |
| `TOKENSPEED_MLA` | TokenSpeed CuTe DSL backend | fp16, bf16 | 10.x | DeepSeek R1 dims only |
> **‡** TRT-LLM Ragged is the default on Blackwell (SM100).
> On other GPUs, FlashAttention is used as the default.
@@ -222,5 +224,6 @@ MLA decode backends are selected using the standard
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 1, 64 | Any | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `TOKENSPEED_MLA` | fp16, bf16 | `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
| `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | Any | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
+3
View File
@@ -23,3 +23,6 @@ fastsafetensors >= 0.2.2
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
nvidia-cutlass-dsl[cu13]>=4.4.2
quack-kernels>=0.3.3
# TokenSpeed MLA kernels for faster MLA attention with speculative decoding
tokenspeed-mla==0.1.2
+21 -12
View File
@@ -224,19 +224,28 @@ def init_test_http_connection():
def dist_init():
from tests.utils import ensure_current_vllm_config
temp_file = tempfile.mkstemp()[1]
# Close the fd returned by mkstemp; FileStore opens the path itself.
# Leaving it open leaks one FD per test and eventually exhausts the
# ulimit, causing FileStore's destructor to throw c10::DistStoreError
# ("Too many open files") during gc and abort the process.
fd, temp_file = tempfile.mkstemp()
os.close(fd)
with ensure_current_vllm_config():
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)
initialize_model_parallel(1, 1)
yield
cleanup_dist_env_and_memory()
try:
with ensure_current_vllm_config():
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)
initialize_model_parallel(1, 1)
yield
cleanup_dist_env_and_memory()
finally:
with contextlib.suppress(OSError):
os.unlink(temp_file)
@pytest.fixture
+66 -7
View File
@@ -20,17 +20,20 @@ from tests.v1.attention.utils import (
from vllm import _custom_ops as ops
from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention.mla_attention import (
MLAAttention,
QueryLenSupport,
_DecodeConcatQuantFP8,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
from vllm.v1.attention.backends.mla.prefill import get_mla_prefill_backend
from vllm.v1.attention.backends.mla.prefill import (
MLAPrefillBackendEnum,
get_mla_prefill_backend,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.v1.kv_cache_interface import MLAAttentionSpec
@@ -41,6 +44,7 @@ BACKENDS_TO_TEST = [
AttentionBackendEnum.FLASH_ATTN_MLA,
AttentionBackendEnum.FLASHINFER_MLA,
AttentionBackendEnum.TRITON_MLA,
AttentionBackendEnum.TOKENSPEED_MLA,
]
DEVICE_TYPE = current_platform.device_type
@@ -49,6 +53,7 @@ DEVICE_TYPE = current_platform.device_type
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA)
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER_MLA)
BACKENDS_TO_TEST.remove(AttentionBackendEnum.TOKENSPEED_MLA)
# Remove FLASH_ATTN_MLA from the list if not supported
if not flash_attn_supports_mla():
@@ -58,6 +63,22 @@ if not flash_attn_supports_mla():
if not is_flashmla_dense_supported()[0]:
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA)
# Remove TOKENSPEED_MLA if the optional package is not installed
if AttentionBackendEnum.TOKENSPEED_MLA in BACKENDS_TO_TEST:
try:
import tokenspeed_mla # noqa: F401
except ImportError:
BACKENDS_TO_TEST.remove(AttentionBackendEnum.TOKENSPEED_MLA)
# Filtered per-test via validate_configuration (capability/deps/dims).
PREFILL_BACKENDS_TO_TEST = [
MLAPrefillBackendEnum.FLASH_ATTN,
MLAPrefillBackendEnum.FLASHINFER,
MLAPrefillBackendEnum.TRTLLM_RAGGED,
MLAPrefillBackendEnum.TOKENSPEED_MLA,
]
SPEC_DECODE_BACKENDS = []
for backend in BACKENDS_TO_TEST:
@@ -389,14 +410,18 @@ class MockSparseMLAAttentionLayer:
return output
class MockMLAAttentionLayer(AttentionLayerBase):
class MockMLAAttentionLayer(MLAAttention):
"""A mock MLA attention layer for testing.
This replicates the forward_impl logic from MLAAttention to allow
testing MLA backends without the full layer infrastructure.
The W_UK_T and W_UV weight matrices are created on the layer (like in
MLAAttention.process_weights_after_loading), not on the impl.
Subclasses MLAAttention so that backends that filter
`static_forward_context` by `isinstance(layer, MLAAttention)` (e.g.
FlashInfer prefill, which reads sm_scale through that filter) see the
mock as a real MLA layer. MLAAttention.__init__ is intentionally
skipped because it would create its own impl/prefill_backend and
self-register in static_forward_context, which conflicts with what the
test sets up below.
"""
def __init__(
@@ -412,6 +437,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
q_scale: float,
k_scale: float,
):
torch.nn.Module.__init__(self)
self.impl = impl
self.num_heads = num_heads
self.qk_nope_head_dim = qk_nope_head_dim
@@ -562,11 +588,15 @@ def run_attention_backend(
q_scale: float,
k_scale: float,
kv_cache_dtype: str = "auto",
prefill_backend: MLAPrefillBackendEnum | None = None,
) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl."""
builder_cls, impl_cls = try_get_attention_backend(backend)
# Force the prefill backend selection (None means auto-select).
vllm_config.attention_config.mla_prefill_backend = prefill_backend
# Set the current vllm config so that get_current_vllm_config() works
# in the backend implementations
with set_current_vllm_config(vllm_config):
@@ -578,7 +608,11 @@ def run_attention_backend(
vllm_config.parallel_config
)
head_size = vllm_config.model_config.get_head_size()
scale = 1.0 / (head_size**0.5)
# Production MLA passes 1/sqrt(qk_head_dim) (the prefill scale) to the
# impl and forwards the same value to the prefill backend. FLASHINFER
# prefill reads sm_scale back from impl.scale via global_hyperparameters
# at plan() time, so impl.scale must agree with prefill_backend.scale.
scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5
impl = impl_cls(
num_heads=num_heads,
head_size=head_size,
@@ -683,6 +717,7 @@ def run_attention_backend(
@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
@pytest.mark.parametrize(("q_scale", "k_scale"), [(1.0, 1.0), (2.0, 3.0)])
@pytest.mark.parametrize("prefill_backend", PREFILL_BACKENDS_TO_TEST)
def test_backend_correctness(
default_vllm_config,
dist_init,
@@ -693,6 +728,7 @@ def test_backend_correctness(
kv_cache_dtype: str,
q_scale: float,
k_scale: float,
prefill_backend: MLAPrefillBackendEnum,
):
"""
Test that all backends produce similar outputs to a reference implementation
@@ -729,6 +765,24 @@ def test_backend_correctness(
if not backends_to_test:
pytest.skip(f"No backends support kv_cache_dtype={kv_cache_dtype}")
# Skip prefill backends that can't satisfy capability/deps/R1 constraints.
from vllm.v1.attention.backends.mla.prefill.selector import (
MLAPrefillSelectorConfig,
)
try:
prefill_invalid_reasons = prefill_backend.get_class().validate_configuration(
current_platform.get_device_capability(),
MLAPrefillSelectorConfig(dtype=torch.bfloat16, is_r1_compatible=True),
)
except ImportError:
prefill_invalid_reasons = ["ImportError"]
if prefill_invalid_reasons:
pytest.skip(
f"Prefill backend {prefill_backend.name} unavailable: "
f"{prefill_invalid_reasons}"
)
batch_spec = BATCH_SPECS[batch_spec_name]
is_spec_decode_test = batch_spec_name.startswith("spec_decode")
unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES[b] for b in backends_to_test))
@@ -799,9 +853,13 @@ def test_backend_correctness(
assert kv_lora_rank + qk_rope_head_dim == head_size, (
f"MLA dimensions don't match: {total_head_size} != {head_size}"
)
decode_scale = 1.0 / (total_head_size**0.5)
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
prefill_scale = qk_head_dim**-0.5
# MLA reuses prefill_scale for the decode path: production sets
# impl.scale = 1/sqrt(qk_head_dim) and the decode kernels apply it even
# though the latent attention runs at head_size dimensions. Keeping the
# reference here in sync with run_attention_backend's impl.scale.
decode_scale = prefill_scale
# 2. Generate data and compute SDPA reference output for MLA
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
@@ -1092,6 +1150,7 @@ def test_backend_correctness(
qk_rope_head_dim,
v_head_dim,
mock_kv_b_proj,
prefill_backend=prefill_backend,
q_scale=q_scale,
k_scale=k_scale,
kv_cache_dtype=kv_cache_dtype,
@@ -1362,6 +1362,7 @@ def backend_supports_prefill_query_quantization() -> bool:
return backend_cls.get_name() in (
"FLASHINFER",
"TRTLLM_RAGGED",
"TOKENSPEED_MLA",
)
+4
View File
@@ -110,6 +110,10 @@ def _get_backend_priorities(
return [
AttentionBackendEnum.FLASHINFER_MLA,
# R1 dims + FP8 KV only; rejected by supports_combination
# otherwise. Behind FLASHINFER_MLA: wins past bs≈8, regresses
# at bs≤2.
AttentionBackendEnum.TOKENSPEED_MLA,
AttentionBackendEnum.CUTLASS_MLA,
AttentionBackendEnum.FLASH_ATTN_MLA,
AttentionBackendEnum.FLASHMLA,
@@ -43,6 +43,10 @@ class MLAPrefillBackendEnum(Enum, metaclass=_MLAPrefillBackendEnumMeta):
"vllm.v1.attention.backends.mla.prefill.trtllm_ragged."
"TrtllmRaggedPrefillBackend"
)
TOKENSPEED_MLA = (
"vllm.v1.attention.backends.mla.prefill.tokenspeed_mla."
"TokenspeedMLAPrefillBackend"
)
def get_path(self) -> str:
"""Get the fully qualified class path for this backend."""
@@ -67,6 +67,7 @@ def _get_mla_prefill_backend_priorities(
MLAPrefillBackendEnum.FLASH_ATTN,
MLAPrefillBackendEnum.TRTLLM_RAGGED,
MLAPrefillBackendEnum.FLASHINFER,
MLAPrefillBackendEnum.TOKENSPEED_MLA,
]
else: # Hopper (SM90) and older
return [
@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TokenSpeed CuTe DSL backend for MLA prefill."""
from typing import TYPE_CHECKING
import torch
from vllm.v1.attention.backends.mla.prefill.base import MLAPrefillBackend
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonPrefillMetadata,
)
from vllm.platforms.interface import DeviceCapability
class TokenspeedMLAPrefillBackend(MLAPrefillBackend):
    """MLA prefill backend that dispatches to the TokenSpeed CuTe DSL kernels.

    The backend declares itself usable only on compute capability 10.x, only
    for R1-style MLA dimensions (`requires_r1_mla_dimensions`), and only when
    the optional `tokenspeed_mla` package can be imported.
    """

    requires_r1_mla_dimensions = True

    # Substituted for the base class's generic missing-dependency reason in
    # `validate_configuration`.
    _INSTALL_HINT = (
        "tokenspeed_mla package is not installed. "
        "Install it with: `uv pip install tokenspeed-mla`"
    )

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this prefill backend."""
        return "TOKENSPEED_MLA"

    @classmethod
    def supports_compute_capability(
        cls, device_capability: "DeviceCapability"
    ) -> bool:
        """Return True only when the device's major capability is 10."""
        major = device_capability.major
        return major == 10

    @classmethod
    def is_available(cls) -> bool:
        """Return whether the `tokenspeed_mla` prefill entry point imports."""
        try:
            from tokenspeed_mla import (
                tokenspeed_mla_prefill,  # noqa: F401
            )
        except ImportError:
            return False
        return True

    @classmethod
    def validate_configuration(
        cls,
        device_capability,
        selector_config,
    ) -> list[str]:
        """Return the base-class rejection reasons with a concrete install hint.

        The generic "required dependencies not available" reason is swapped
        for `_INSTALL_HINT`, so a user who explicitly selects this backend
        without `tokenspeed_mla` installed is told which package to install.
        All other reasons pass through unchanged.
        """
        generic_reason = "required dependencies not available"
        base_reasons = super().validate_configuration(
            device_capability, selector_config
        )
        rewritten: list[str] = []
        for reason in base_reasons:
            if reason == generic_reason:
                rewritten.append(cls._INSTALL_HINT)
            else:
                rewritten.append(reason)
        return rewritten

    def __init__(
        self,
        num_heads: int,
        scale: float,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        vllm_config: "VllmConfig",
    ) -> None:
        """Initialise the base backend and pre-compile the prefill kernels."""
        super().__init__(
            num_heads=num_heads,
            scale=scale,
            kv_lora_rank=kv_lora_rank,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            vllm_config=vllm_config,
        )

        # Trigger JIT compilation for both the BF16 and the FP8 query
        # variants up front. TokenspeedMLAImpl.__init__ issues the same
        # warm-up; the call is idempotent, so the second one is a no-op.
        from tokenspeed_mla import warmup_compile_prefill

        qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        for warmup_dtype in (torch.bfloat16, torch.float8_e4m3fn):
            warmup_compile_prefill(
                q_dtype=warmup_dtype,
                d_qk=qk_head_dim,
                d_v=v_head_dim,
                enable_pdl=False,
            )

    def prepare_metadata(
        self,
        prefill_metadata: "MLACommonPrefillMetadata",
    ) -> None:
        """Run the base preparation and cache the per-request query lengths."""
        super().prepare_metadata(prefill_metadata)

        # The kernel signature requires `seq_lens` although the implementation
        # never reads it (per-batch lengths come from `cum_seq_lens` diffs);
        # it is computed for parity with trtllm_ragged. CUDA-graph padding in
        # `query_start_loc` is saturated to `total_num_tokens`
        # (gpu_model_runner.py:1905), so trailing diffs are 0 and padded
        # batches are kernel no-ops — the same reason trtllm passes the padded
        # length as batch_size directly.
        start_loc = prefill_metadata.query_start_loc
        self._query_seq_lens = start_loc[1:] - start_loc[:-1]

    def run_prefill_new_tokens(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        return_softmax_lse: bool,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run causal prefill attention over the newly scheduled tokens.

        Returns the kernel result as-is when it is not a tuple. When it is a
        tuple, returns its first element together with the second element
        transposed on its first two dimensions and made contiguous.
        """
        from tokenspeed_mla import tokenspeed_mla_prefill

        # `v` arrives as the second half of `kv_nope.split(...)` in
        # mla_attention.forward_mha, i.e. a non-contiguous view along dim=-1.
        # The kernel does `v.reshape(1, total_kv, h_k, 1, d_v)`, which would
        # silently copy a non-contiguous tensor; forcing contiguity here makes
        # that copy (if any) happen once, outside the kernel call.
        v = v.contiguous()

        metadata = self._prefill_metadata
        query_lens = self._query_seq_lens
        result = tokenspeed_mla_prefill(
            query=q,
            key=k,
            value=v,
            seq_lens=query_lens,
            cum_seq_lens=metadata.query_start_loc,
            max_seq_len=metadata.max_query_len,
            batch_size=query_lens.shape[0],
            softmax_scale=self.scale,
            is_causal=True,
            return_lse=return_softmax_lse,
            enable_pdl=False,
        )
        if not isinstance(result, tuple):
            return result
        # Convert the LSE from (q_len, num_heads) to (num_heads, q_len).
        return result[0], result[1].transpose(0, 1).contiguous()

    def run_prefill_context_chunk(
        self,
        chunk_idx: int,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run non-causal attention of the queries against one context chunk.

        Returns the attention output together with the LSE transposed on its
        first two dimensions and made contiguous.
        """
        from tokenspeed_mla import tokenspeed_mla_prefill

        metadata = self._prefill_metadata
        chunked = metadata.chunked_context
        assert chunked is not None

        # As in run_prefill_new_tokens: `v` is a split-view of `kv_nope`
        # produced in `_compute_prefill_context` and arrives non-contiguous.
        v = v.contiguous()

        chunk_seq_lens = chunked.seq_lens[chunk_idx]
        attn_out, lse = tokenspeed_mla_prefill(
            query=q,
            key=k,
            value=v,
            seq_lens=chunk_seq_lens,
            cum_seq_lens=chunked.cu_seq_lens[chunk_idx],
            max_seq_len=chunked.max_seq_lens[chunk_idx],
            batch_size=chunk_seq_lens.shape[0],
            softmax_scale=self.scale,
            is_causal=False,
            return_lse=True,
            cum_seq_lens_q=metadata.query_start_loc,
            max_seq_len_q=metadata.max_query_len,
            enable_pdl=False,
        )
        # Convert the LSE from (q_len, num_heads) to (num_heads, q_len).
        return attn_out, lse.transpose(0, 1).contiguous()
@@ -0,0 +1,277 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TokenSpeed CuTe DSL MLA decode backend (Blackwell, FP8 KV cache only)."""
from typing import ClassVar
import torch
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.utils import KVCacheLayoutType
logger = init_logger(__name__)
# Workspace upper bound for tokenspeed_mla_decode (per-device, lazy):
# num_sms * num_heads * MAX_Q_LEN * (kv_lora_rank + 1) * sizeof(float32)
# Matches the kernel's `get_workspace_size` formula. MAX_Q_LEN=8 covers up to
# EAGLE3 / MTP-2 spec decoding query lengths; larger q_len fails the kernel's
# own buffer check.
_TOKENSPEED_MAX_Q_LEN = 8
_g_workspace: dict[torch.device, torch.Tensor] = {}
def _get_workspace(
    device: torch.device, num_heads: int, kv_lora_rank: int
) -> torch.Tensor:
    """Return the shared decode workspace for `device`, growing it if needed.

    The required size in bytes is
    `num_sms * num_heads * _TOKENSPEED_MAX_Q_LEN * (kv_lora_rank + 1) * 4`.
    One int8 buffer is cached per device in `_g_workspace`; it is reallocated
    only when no buffer exists yet or the cached one is smaller than required.
    """
    from tokenspeed_mla import get_num_sm

    float32_nbytes = 4
    required_bytes = (
        get_num_sm(device)
        * num_heads
        * _TOKENSPEED_MAX_Q_LEN
        * (kv_lora_rank + 1)
        * float32_nbytes
    )

    cached = _g_workspace.get(device)
    # The buffer dtype is int8, so numel() is also its size in bytes.
    if cached is not None and cached.numel() >= required_bytes:
        return cached

    workspace = torch.empty(required_bytes, dtype=torch.int8, device=device)
    _g_workspace[device] = workspace
    return workspace
class TokenspeedMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
    """Metadata builder returned by `TokenspeedMLABackend.get_builder_cls`.

    Adds no logic of its own on top of `MLACommonMetadataBuilder`; it only
    sets the two class-level capability declarations below.
    """

    # CUDA-graph support level declared for this backend.
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
    # Query-length support level declared for this backend.
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
class TokenspeedMLABackend(MLACommonBackend):
    """Attention backend descriptor for the TokenSpeed CuTe DSL MLA decode path.

    Declares fp16/bf16 model dtypes, FP8 KV-cache dtypes only, kernel block
    sizes 32 and 64, compute capability 10.x, and the "HND" KV-cache layout.
    """

    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["fp8", "fp8_e4m3"]

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "TOKENSPEED_MLA"

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the kernel block sizes this backend accepts."""
        return [32, 64]

    @staticmethod
    def get_impl_cls() -> type["TokenspeedMLAImpl"]:
        """Return the attention implementation class."""
        return TokenspeedMLAImpl

    @staticmethod
    def get_builder_cls() -> type["TokenspeedMLAMetadataBuilder"]:
        """Return the metadata builder class."""
        return TokenspeedMLAMetadataBuilder

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return True only when the device's major capability is 10."""
        major = capability.major
        return major == 10

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: CacheDType | None,
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: DeviceCapability,
    ) -> str | None:
        """Return a rejection reason, or None when the configuration is usable.

        Two conditions are checked: the `tokenspeed_mla` package must be
        importable, and, when a model config is present, its text config must
        carry the DeepSeek R1 MLA dimensions (128, 64, 128). None of the
        method's own parameters are inspected.
        """
        # Report a clear install hint up front rather than letting a raw
        # ModuleNotFoundError fire deep inside `forward_mqa` at first request.
        try:
            import tokenspeed_mla  # noqa: F401
        except ImportError:
            return (
                "tokenspeed_mla package is not installed. "
                "Install it with: `uv pip install tokenspeed-mla`"
            )

        from vllm.config import get_current_vllm_config

        model_config = get_current_vllm_config().model_config
        if model_config is None:
            return None

        # The tokenspeed_mla CuTe DSL kernel is shape-specialized for the
        # DeepSeek R1 MLA dimensions; a missing attribute reads as 0 and is
        # therefore rejected as well.
        text_config = model_config.hf_text_config
        nope_dim = getattr(text_config, "qk_nope_head_dim", 0)
        rope_dim = getattr(text_config, "qk_rope_head_dim", 0)
        value_dim = getattr(text_config, "v_head_dim", 0)
        if nope_dim == 128 and rope_dim == 64 and value_dim == 128:
            return None
        return (
            "tokenspeed_mla requires DeepSeek R1 MLA dimensions "
            "(qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128), "
            f"got ({nope_dim}, {rope_dim}, {value_dim})"
        )

    @classmethod
    def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
        """Return the KV-cache layout this backend requires."""
        return "HND"
class TokenspeedMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA decode implementation backed by `tokenspeed_mla_decode`.

    Requires a quantized (FP8) KV cache and an FP8 query. The workspace
    buffer and the softmax/output scales are resolved lazily on the first
    `forward_mqa` call and then reused.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        """Validate the configuration and pre-compile the prefill kernels.

        Raises:
            NotImplementedError: if alibi slopes, a sliding window or a logits
                soft cap is requested, if `attn_type` is not the decoder
                type, or if the KV cache dtype is not quantized.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "TokenspeedMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "TokenspeedMLAImpl"
            )

        if not is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "TokenspeedMLAImpl requires an FP8 KV cache "
                "(--kv-cache-dtype fp8 or fp8_e4m3); "
                f"got kv_cache_dtype={self.kv_cache_dtype!r}."
            )

        # The workspace is allocated (or fetched from the per-device cache)
        # lazily on the first forward: __init__ may run before the device is
        # set on the worker, whereas at forward time the input tensor's
        # device is known.
        self._workspace_buffer: torch.Tensor | None = None
        self.softmax_scale: float | None = None
        self.output_scale: float | None = None

        # Pre-JIT the BF16 and FP8 prefill kernels here too: the decode impl
        # always runs when tokenspeed is selected, while the tokenspeed
        # prefill backend may not (it can be paired with flash_attn / trtllm).
        # The warm-up is idempotent.
        from tokenspeed_mla import warmup_compile_prefill

        for q_dtype in (torch.bfloat16, torch.float8_e4m3fn):
            warmup_compile_prefill(
                q_dtype=q_dtype,
                d_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                d_v=self.v_head_dim,
                enable_pdl=False,
            )

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run MLA decode attention through `tokenspeed_mla_decode`.

        Args:
            q: Either the query tensor or a `(q_nope, q_pe)` pair, which is
                concatenated along the last dimension.
            kv_c_and_k_pe_cache: The paged KV cache, passed to the kernel
                unchanged.
            attn_metadata: Attention metadata; its `decode` part must be set.
            layer: The attention layer, read for `_q_scale_float` and
                `_k_scale_float`.

        Returns:
            The attention output flattened to three dimensions, and `None`
            because the kernel does not return an LSE.
        """
        from tokenspeed_mla import tokenspeed_mla_decode

        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        # Remember the input form before `q` is rebound, so the diagnostic
        # below reports what the pipeline actually passed in.
        q_was_tuple = False
        if isinstance(q, tuple):
            q_was_tuple = True
            q_nope, q_pe = q
            q = torch.cat([q_nope, q_pe], dim=-1)

        # supports_quant_query_input=True (set in MLACommonImpl) tells the
        # pipeline to concat+FP8-quantize Q upstream via
        # _decode_concat_quant_fp8_op. The kernel is shape-specialized for
        # FP8 Q + FP8 KV, so any other dtype here means the upstream quant did
        # not run and the kernel would produce garbage.
        assert q.dtype == torch.float8_e4m3fn, (
            f"TokenspeedMLAImpl expected FP8 query (supports_quant_query_input=True), "
            f"got {q.dtype}. Pipeline isinstance(q, tuple)={q_was_tuple}, "
            f"q_scale={layer._q_scale_float}, k_scale={layer._k_scale_float}."
        )

        # tokenspeed_mla_decode expects query shape
        # (num_decodes, q_len_per_request, num_heads, head_dim).
        if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0:
            logger.warning_once(
                "TokenspeedMLAImpl got a query of uneven length.\n"
                "This usually indicates an issue in batch reordering\n"
                "or incorrect setup in dummy_run."
            )
            q = q.unsqueeze(1)
        else:
            q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])

        if self.softmax_scale is None:
            # FP8 KV cache is mandatory for this backend, so q_scale/k_scale
            # always apply. softmax_scale is bmm1; output_scale is bmm2 — both
            # are required to recover the correct attention output from the
            # FP8 KV cache (V is stored as V_real/k_scale).
            self.softmax_scale = (
                self.scale * layer._q_scale_float * layer._k_scale_float
            )
            self.output_scale = layer._k_scale_float

        if self._workspace_buffer is None:
            self._workspace_buffer = _get_workspace(
                q.device, self.num_heads, self.kv_lora_rank
            )

        # vLLM kv_c_and_k_pe_cache is already
        # (num_blocks, block_size, head_size). tokenspeed_mla_decode wants 3D,
        # so it is passed as-is (no unsqueeze, unlike trtllm).
        o = tokenspeed_mla_decode(
            query=q,
            kv_cache=kv_c_and_k_pe_cache,
            workspace_buffer=self._workspace_buffer,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=attn_metadata.decode.block_table,
            seq_lens=attn_metadata.decode.seq_lens,
            max_seq_len=attn_metadata.max_seq_len,
            softmax_scale=self.softmax_scale,
            output_scale=self.output_scale,
            enable_pdl=False,
        )

        # Flatten the output for a consistent shape.
        o = o.view(-1, o.shape[-2], o.shape[-1])

        # tokenspeed_mla_decode does not return LSE.
        return o, None
+3
View File
@@ -63,6 +63,9 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
FLASHINFER_MLA = (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
)
TOKENSPEED_MLA = (
"vllm.v1.attention.backends.mla.tokenspeed_mla.TokenspeedMLABackend"
)
FLASHINFER_MLA_SPARSE = (
"vllm.v1.attention.backends.mla.flashinfer_mla_sparse."
"FlashInferMLASparseBackend"