mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[MLA Attention Backend] Add TOKENSPEED_MLA backend for DSR1/Kimi K25 prefill + decode on Blackwell (#41778)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -53,6 +53,7 @@ backends:
|
||||
- FLASHINFER_MLA
|
||||
- FLASH_ATTN_MLA # Hopper only
|
||||
- FLASHMLA # Hopper only
|
||||
- TOKENSPEED_MLA # Blackwell + R1 dims + FP8 KV (use --kv-cache-dtype fp8)
|
||||
|
||||
device: "cuda:0"
|
||||
repeats: 100
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# Compares all available MLA prefill backends:
|
||||
# FA backends: fa2, fa3, fa4 (FlashAttention versions)
|
||||
# Non-FA: flashinfer, cudnn, trtllm (Blackwell-only, require flashinfer)
|
||||
# CuTe DSL: tokenspeed (Blackwell + R1 dims, requires tokenspeed_mla)
|
||||
#
|
||||
# Uses cutlass_mla as the decode backend for impl construction
|
||||
# (only the prefill path is exercised).
|
||||
@@ -120,6 +121,7 @@ prefill_backends:
|
||||
- flashinfer
|
||||
- cudnn
|
||||
- trtllm
|
||||
- tokenspeed
|
||||
|
||||
device: "cuda:0"
|
||||
repeats: 20
|
||||
|
||||
@@ -179,19 +179,27 @@ def create_minimal_vllm_config(
|
||||
|
||||
if prefill_backend is not None:
|
||||
prefill_cfg = get_prefill_backend_config(prefill_backend)
|
||||
if prefill_cfg["flash_attn_version"] is not None:
|
||||
vllm_config.attention_config.flash_attn_version = prefill_cfg[
|
||||
"flash_attn_version"
|
||||
if prefill_cfg.get("mla_prefill_backend_enum") is not None:
|
||||
# Registry-based backends bypass the deprecated boolean flags.
|
||||
from vllm.v1.attention.backends.mla.prefill import MLAPrefillBackendEnum
|
||||
|
||||
vllm_config.attention_config.mla_prefill_backend = MLAPrefillBackendEnum[
|
||||
prefill_cfg["mla_prefill_backend_enum"]
|
||||
]
|
||||
vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[
|
||||
"disable_flashinfer_prefill"
|
||||
]
|
||||
vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[
|
||||
"use_cudnn_prefill"
|
||||
]
|
||||
vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = prefill_cfg[
|
||||
"use_trtllm_ragged_deepseek_prefill"
|
||||
]
|
||||
else:
|
||||
if prefill_cfg["flash_attn_version"] is not None:
|
||||
vllm_config.attention_config.flash_attn_version = prefill_cfg[
|
||||
"flash_attn_version"
|
||||
]
|
||||
vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[
|
||||
"disable_flashinfer_prefill"
|
||||
]
|
||||
vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[
|
||||
"use_cudnn_prefill"
|
||||
]
|
||||
vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = (
|
||||
prefill_cfg["use_trtllm_ragged_deepseek_prefill"]
|
||||
)
|
||||
|
||||
return vllm_config
|
||||
|
||||
@@ -223,22 +231,17 @@ _PREFILL_BACKEND_CONFIG: dict[str, dict] = {
|
||||
"use_trtllm_ragged_deepseek_prefill": False,
|
||||
},
|
||||
"flashinfer": {
|
||||
"flash_attn_version": None,
|
||||
"disable_flashinfer_prefill": False,
|
||||
"use_cudnn_prefill": False,
|
||||
"use_trtllm_ragged_deepseek_prefill": False,
|
||||
"mla_prefill_backend_enum": "FLASHINFER",
|
||||
},
|
||||
"cudnn": {
|
||||
"flash_attn_version": None,
|
||||
"disable_flashinfer_prefill": True,
|
||||
"use_cudnn_prefill": True,
|
||||
"use_trtllm_ragged_deepseek_prefill": False,
|
||||
# cuDNN prefill backend was removed; AttentionConfig raises on use.
|
||||
"mla_prefill_backend_enum": "FLASHINFER",
|
||||
},
|
||||
"trtllm": {
|
||||
"flash_attn_version": None,
|
||||
"disable_flashinfer_prefill": True,
|
||||
"use_cudnn_prefill": False,
|
||||
"use_trtllm_ragged_deepseek_prefill": True,
|
||||
"mla_prefill_backend_enum": "TRTLLM_RAGGED",
|
||||
},
|
||||
"tokenspeed": {
|
||||
"mla_prefill_backend_enum": "TOKENSPEED_MLA",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -625,6 +628,21 @@ def _create_backend_impl(
|
||||
# Create mock layer
|
||||
layer = MockLayer(device, impl=impl, kv_cache_spec=kv_cache_spec)
|
||||
|
||||
# Attach a prefill backend (MLAAttention does this in __init__; the metadata
|
||||
# builder reads layer.prefill_backend from static_forward_context).
|
||||
from vllm.v1.attention.backends.mla.prefill import get_mla_prefill_backend
|
||||
|
||||
prefill_backend_cls = get_mla_prefill_backend(vllm_config)
|
||||
layer.prefill_backend = prefill_backend_cls(
|
||||
num_heads=mla_dims["num_q_heads"],
|
||||
scale=(mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"]) ** -0.5,
|
||||
kv_lora_rank=mla_dims["kv_lora_rank"],
|
||||
qk_nope_head_dim=mla_dims["qk_nope_head_dim"],
|
||||
qk_rope_head_dim=mla_dims["qk_rope_head_dim"],
|
||||
v_head_dim=mla_dims["v_head_dim"],
|
||||
vllm_config=vllm_config,
|
||||
)
|
||||
|
||||
# Create builder instance if needed
|
||||
builder_instance = None
|
||||
if builder_class:
|
||||
@@ -961,19 +979,6 @@ def _run_mla_benchmark_batched(
|
||||
results = []
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
# Clear cached prefill backend detection functions so they re-evaluate
|
||||
# with the current VllmConfig. These are @functools.cache decorated and
|
||||
# would otherwise return stale results from a previous backend's config.
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
use_cudnn_prefill,
|
||||
use_flashinfer_prefill,
|
||||
use_trtllm_ragged_deepseek_prefill,
|
||||
)
|
||||
|
||||
use_flashinfer_prefill.cache_clear()
|
||||
use_cudnn_prefill.cache_clear()
|
||||
use_trtllm_ragged_deepseek_prefill.cache_clear()
|
||||
|
||||
# Create backend impl, layer, builder, and indexer (reused across benchmarks)
|
||||
impl, layer, builder_instance, indexer = _create_backend_impl(
|
||||
backend_cfg,
|
||||
@@ -985,36 +990,35 @@ def _run_mla_benchmark_batched(
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
)
|
||||
|
||||
# Verify the actual prefill backend matches what was requested
|
||||
# Verify the actual prefill backend matches what was requested. The
|
||||
# selector + impl construction already raise on misuse; here we just
|
||||
# check the resolved class against the requested name as a sanity guard.
|
||||
if prefill_backend is not None:
|
||||
prefill_cfg = get_prefill_backend_config(prefill_backend)
|
||||
fa_version = prefill_cfg["flash_attn_version"]
|
||||
|
||||
if fa_version is not None:
|
||||
# FA backend: verify the impl's FA version
|
||||
actual_fa_version = getattr(impl, "vllm_flash_attn_version", None)
|
||||
expected_class = {
|
||||
"fa2": "FlashAttnPrefillBackend",
|
||||
"fa3": "FlashAttnPrefillBackend",
|
||||
"fa4": "FlashAttnPrefillBackend",
|
||||
"flashinfer": "FlashInferPrefillBackend",
|
||||
"trtllm": "TrtllmRaggedPrefillBackend",
|
||||
"tokenspeed": "TokenspeedMLAPrefillBackend",
|
||||
}.get(prefill_backend)
|
||||
actual_class = type(getattr(layer, "prefill_backend", None)).__name__
|
||||
if expected_class and actual_class != expected_class:
|
||||
raise RuntimeError(
|
||||
f"Prefill backend '{prefill_backend}' requested "
|
||||
f"{expected_class}, got {actual_class}. Check "
|
||||
f"attention_config plumbing or installed deps."
|
||||
)
|
||||
if prefill_backend in {"fa2", "fa3", "fa4"}:
|
||||
fa_version = int(prefill_backend[2:])
|
||||
actual_fa_version = getattr(
|
||||
layer.prefill_backend, "vllm_flash_attn_version", None
|
||||
)
|
||||
if actual_fa_version != fa_version:
|
||||
raise RuntimeError(
|
||||
f"Prefill backend '{prefill_backend}' requested FA "
|
||||
f"version {fa_version}, but the impl is using FA "
|
||||
f"version {actual_fa_version}. Check "
|
||||
f"vllm/v1/attention/backends/fa_utils.py."
|
||||
)
|
||||
else:
|
||||
# Non-FA backend: verify the builder picked the right path
|
||||
expected_flags = {
|
||||
"flashinfer": "_use_fi_prefill",
|
||||
"cudnn": "_use_cudnn_prefill",
|
||||
"trtllm": "_use_trtllm_ragged_prefill",
|
||||
}
|
||||
flag_name = expected_flags.get(prefill_backend)
|
||||
if flag_name and not getattr(builder_instance, flag_name, False):
|
||||
raise RuntimeError(
|
||||
f"Prefill backend '{prefill_backend}' was requested "
|
||||
f"but the metadata builder did not enable it. This "
|
||||
f"usually means a dependency is missing (e.g., "
|
||||
f"flashinfer not installed) or the platform doesn't "
|
||||
f"support it."
|
||||
f"version {fa_version}, got "
|
||||
f"{actual_fa_version} on {actual_class}."
|
||||
)
|
||||
|
||||
# Run each benchmark with the shared impl
|
||||
|
||||
@@ -125,12 +125,13 @@ Priority is **1 = highest** (tried first).
|
||||
| Priority | Backend |
|
||||
| -------- | ------- |
|
||||
| 1 | `FLASHINFER_MLA` |
|
||||
| 2 | `CUTLASS_MLA` |
|
||||
| 3 | `FLASH_ATTN_MLA` |
|
||||
| 4 | `FLASHMLA` |
|
||||
| 5 | `TRITON_MLA` |
|
||||
| 6 | `FLASHINFER_MLA_SPARSE`**\*** |
|
||||
| 7 | `FLASHMLA_SPARSE` |
|
||||
| 2 | `TOKENSPEED_MLA` |
|
||||
| 3 | `CUTLASS_MLA` |
|
||||
| 4 | `FLASH_ATTN_MLA` |
|
||||
| 5 | `FLASHMLA` |
|
||||
| 6 | `TRITON_MLA` |
|
||||
| 7 | `FLASHINFER_MLA_SPARSE`**\*** |
|
||||
| 8 | `FLASHMLA_SPARSE` |
|
||||
|
||||
**Ampere/Hopper (SM 8.x-9.x):**
|
||||
|
||||
@@ -202,6 +203,7 @@ hardware and configuration.
|
||||
| `FLASH_ATTN`‡ | FlashAttention varlen (FA2/FA3/FA4) | fp16, bf16 | Any | FA4 on SM100+, FA3 on SM90, FA2 otherwise |
|
||||
| `TRTLLM_RAGGED` | TensorRT-LLM ragged attention | fp16, bf16 | 10.x | DeepSeek R1 dims only |
|
||||
| `FLASHINFER` | FlashInfer CUTLASS backend | fp16, bf16 | 10.x | DeepSeek R1 dims only |
|
||||
| `TOKENSPEED_MLA` | | fp16, bf16 | 10.x | DeepSeek R1 dims only |
|
||||
|
||||
> **‡** TRT-LLM Ragged is the default on Blackwell (SM100).
|
||||
> On other GPUs, FlashAttention is used as the default.
|
||||
@@ -222,5 +224,6 @@ MLA decode backends are selected using the standard
|
||||
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
||||
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 1, 64 | Any | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
|
||||
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
||||
| `TOKENSPEED_MLA` | fp16, bf16 | `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
|
||||
| `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
|
||||
| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | Any | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
|
||||
|
||||
@@ -23,3 +23,6 @@ fastsafetensors >= 0.2.2
|
||||
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
|
||||
nvidia-cutlass-dsl[cu13]>=4.4.2
|
||||
quack-kernels>=0.3.3
|
||||
|
||||
# Tokenspeed_MLA for faster mla with spec decode
|
||||
tokenspeed-mla==0.1.2
|
||||
+21
-12
@@ -224,19 +224,28 @@ def init_test_http_connection():
|
||||
def dist_init():
|
||||
from tests.utils import ensure_current_vllm_config
|
||||
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
# Close the fd returned by mkstemp; FileStore opens the path itself.
|
||||
# Leaving it open leaks one FD per test and eventually exhausts the
|
||||
# ulimit, causing FileStore's destructor to throw c10::DistStoreError
|
||||
# ("Too many open files") during gc and abort the process.
|
||||
fd, temp_file = tempfile.mkstemp()
|
||||
os.close(fd)
|
||||
|
||||
with ensure_current_vllm_config():
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=0,
|
||||
backend="nccl",
|
||||
)
|
||||
initialize_model_parallel(1, 1)
|
||||
yield
|
||||
cleanup_dist_env_and_memory()
|
||||
try:
|
||||
with ensure_current_vllm_config():
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=0,
|
||||
backend="nccl",
|
||||
)
|
||||
initialize_model_parallel(1, 1)
|
||||
yield
|
||||
cleanup_dist_env_and_memory()
|
||||
finally:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(temp_file)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -20,17 +20,20 @@ from tests.v1.attention.utils import (
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config.vllm import set_current_vllm_config
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLAAttention,
|
||||
QueryLenSupport,
|
||||
_DecodeConcatQuantFP8,
|
||||
)
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.attention.backend import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
|
||||
from vllm.v1.attention.backends.mla.prefill import get_mla_prefill_backend
|
||||
from vllm.v1.attention.backends.mla.prefill import (
|
||||
MLAPrefillBackendEnum,
|
||||
get_mla_prefill_backend,
|
||||
)
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
||||
@@ -41,6 +44,7 @@ BACKENDS_TO_TEST = [
|
||||
AttentionBackendEnum.FLASH_ATTN_MLA,
|
||||
AttentionBackendEnum.FLASHINFER_MLA,
|
||||
AttentionBackendEnum.TRITON_MLA,
|
||||
AttentionBackendEnum.TOKENSPEED_MLA,
|
||||
]
|
||||
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
@@ -49,6 +53,7 @@ DEVICE_TYPE = current_platform.device_type
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
|
||||
BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA)
|
||||
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER_MLA)
|
||||
BACKENDS_TO_TEST.remove(AttentionBackendEnum.TOKENSPEED_MLA)
|
||||
|
||||
# Remove FLASH_ATTN_MLA from the list if not supported
|
||||
if not flash_attn_supports_mla():
|
||||
@@ -58,6 +63,22 @@ if not flash_attn_supports_mla():
|
||||
if not is_flashmla_dense_supported()[0]:
|
||||
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA)
|
||||
|
||||
# Remove TOKENSPEED_MLA if the optional package is not installed
|
||||
if AttentionBackendEnum.TOKENSPEED_MLA in BACKENDS_TO_TEST:
|
||||
try:
|
||||
import tokenspeed_mla # noqa: F401
|
||||
except ImportError:
|
||||
BACKENDS_TO_TEST.remove(AttentionBackendEnum.TOKENSPEED_MLA)
|
||||
|
||||
|
||||
# Filtered per-test via validate_configuration (capability/deps/dims).
|
||||
PREFILL_BACKENDS_TO_TEST = [
|
||||
MLAPrefillBackendEnum.FLASH_ATTN,
|
||||
MLAPrefillBackendEnum.FLASHINFER,
|
||||
MLAPrefillBackendEnum.TRTLLM_RAGGED,
|
||||
MLAPrefillBackendEnum.TOKENSPEED_MLA,
|
||||
]
|
||||
|
||||
|
||||
SPEC_DECODE_BACKENDS = []
|
||||
for backend in BACKENDS_TO_TEST:
|
||||
@@ -389,14 +410,18 @@ class MockSparseMLAAttentionLayer:
|
||||
return output
|
||||
|
||||
|
||||
class MockMLAAttentionLayer(AttentionLayerBase):
|
||||
class MockMLAAttentionLayer(MLAAttention):
|
||||
"""A mock MLA attention layer for testing.
|
||||
|
||||
This replicates the forward_impl logic from MLAAttention to allow
|
||||
testing MLA backends without the full layer infrastructure.
|
||||
|
||||
The W_UK_T and W_UV weight matrices are created on the layer (like in
|
||||
MLAAttention.process_weights_after_loading), not on the impl.
|
||||
Subclasses MLAAttention so that backends that filter
|
||||
`static_forward_context` by `isinstance(layer, MLAAttention)` (e.g.
|
||||
FlashInfer prefill, which reads sm_scale through that filter) see the
|
||||
mock as a real MLA layer. MLAAttention.__init__ is intentionally
|
||||
skipped — it would create its own impl/prefill_backend and self-register
|
||||
in static_forward_context, which fights what the test sets up below.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -412,6 +437,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
|
||||
q_scale: float,
|
||||
k_scale: float,
|
||||
):
|
||||
torch.nn.Module.__init__(self)
|
||||
self.impl = impl
|
||||
self.num_heads = num_heads
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
@@ -562,11 +588,15 @@ def run_attention_backend(
|
||||
q_scale: float,
|
||||
k_scale: float,
|
||||
kv_cache_dtype: str = "auto",
|
||||
prefill_backend: MLAPrefillBackendEnum | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Run attention computation using the specified backend's AttentionImpl."""
|
||||
|
||||
builder_cls, impl_cls = try_get_attention_backend(backend)
|
||||
|
||||
# Force the prefill backend selection (None means auto-select).
|
||||
vllm_config.attention_config.mla_prefill_backend = prefill_backend
|
||||
|
||||
# Set the current vllm config so that get_current_vllm_config() works
|
||||
# in the backend implementations
|
||||
with set_current_vllm_config(vllm_config):
|
||||
@@ -578,7 +608,11 @@ def run_attention_backend(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
head_size = vllm_config.model_config.get_head_size()
|
||||
scale = 1.0 / (head_size**0.5)
|
||||
# Production MLA passes 1/sqrt(qk_head_dim) (the prefill scale) to the
|
||||
# impl and forwards the same value to the prefill backend. FLASHINFER
|
||||
# prefill reads sm_scale back from impl.scale via global_hyperparameters
|
||||
# at plan() time, so impl.scale must agree with prefill_backend.scale.
|
||||
scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5
|
||||
impl = impl_cls(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
@@ -683,6 +717,7 @@ def run_attention_backend(
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
|
||||
@pytest.mark.parametrize(("q_scale", "k_scale"), [(1.0, 1.0), (2.0, 3.0)])
|
||||
@pytest.mark.parametrize("prefill_backend", PREFILL_BACKENDS_TO_TEST)
|
||||
def test_backend_correctness(
|
||||
default_vllm_config,
|
||||
dist_init,
|
||||
@@ -693,6 +728,7 @@ def test_backend_correctness(
|
||||
kv_cache_dtype: str,
|
||||
q_scale: float,
|
||||
k_scale: float,
|
||||
prefill_backend: MLAPrefillBackendEnum,
|
||||
):
|
||||
"""
|
||||
Test that all backends produce similar outputs to a reference implementation
|
||||
@@ -729,6 +765,24 @@ def test_backend_correctness(
|
||||
if not backends_to_test:
|
||||
pytest.skip(f"No backends support kv_cache_dtype={kv_cache_dtype}")
|
||||
|
||||
# Skip prefill backends that can't satisfy capability/deps/R1 constraints.
|
||||
from vllm.v1.attention.backends.mla.prefill.selector import (
|
||||
MLAPrefillSelectorConfig,
|
||||
)
|
||||
|
||||
try:
|
||||
prefill_invalid_reasons = prefill_backend.get_class().validate_configuration(
|
||||
current_platform.get_device_capability(),
|
||||
MLAPrefillSelectorConfig(dtype=torch.bfloat16, is_r1_compatible=True),
|
||||
)
|
||||
except ImportError:
|
||||
prefill_invalid_reasons = ["ImportError"]
|
||||
if prefill_invalid_reasons:
|
||||
pytest.skip(
|
||||
f"Prefill backend {prefill_backend.name} unavailable: "
|
||||
f"{prefill_invalid_reasons}"
|
||||
)
|
||||
|
||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||
is_spec_decode_test = batch_spec_name.startswith("spec_decode")
|
||||
unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES[b] for b in backends_to_test))
|
||||
@@ -799,9 +853,13 @@ def test_backend_correctness(
|
||||
assert kv_lora_rank + qk_rope_head_dim == head_size, (
|
||||
f"MLA dimensions don't match: {total_head_size} != {head_size}"
|
||||
)
|
||||
decode_scale = 1.0 / (total_head_size**0.5)
|
||||
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
prefill_scale = qk_head_dim**-0.5
|
||||
# MLA reuses prefill_scale for the decode path: production sets
|
||||
# impl.scale = 1/sqrt(qk_head_dim) and the decode kernels apply it even
|
||||
# though the latent attention runs at head_size dimensions. Keeping the
|
||||
# reference here in sync with run_attention_backend's impl.scale.
|
||||
decode_scale = prefill_scale
|
||||
|
||||
# 2. Generate data and compute SDPA reference output for MLA
|
||||
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
|
||||
@@ -1092,6 +1150,7 @@ def test_backend_correctness(
|
||||
qk_rope_head_dim,
|
||||
v_head_dim,
|
||||
mock_kv_b_proj,
|
||||
prefill_backend=prefill_backend,
|
||||
q_scale=q_scale,
|
||||
k_scale=k_scale,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
|
||||
@@ -1362,6 +1362,7 @@ def backend_supports_prefill_query_quantization() -> bool:
|
||||
return backend_cls.get_name() in (
|
||||
"FLASHINFER",
|
||||
"TRTLLM_RAGGED",
|
||||
"TOKENSPEED_MLA",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -110,6 +110,10 @@ def _get_backend_priorities(
|
||||
|
||||
return [
|
||||
AttentionBackendEnum.FLASHINFER_MLA,
|
||||
# R1 dims + FP8 KV only; rejected by supports_combination
|
||||
# otherwise. Behind FLASHINFER_MLA: wins past bs≈8, regresses
|
||||
# at bs≤2.
|
||||
AttentionBackendEnum.TOKENSPEED_MLA,
|
||||
AttentionBackendEnum.CUTLASS_MLA,
|
||||
AttentionBackendEnum.FLASH_ATTN_MLA,
|
||||
AttentionBackendEnum.FLASHMLA,
|
||||
|
||||
@@ -43,6 +43,10 @@ class MLAPrefillBackendEnum(Enum, metaclass=_MLAPrefillBackendEnumMeta):
|
||||
"vllm.v1.attention.backends.mla.prefill.trtllm_ragged."
|
||||
"TrtllmRaggedPrefillBackend"
|
||||
)
|
||||
TOKENSPEED_MLA = (
|
||||
"vllm.v1.attention.backends.mla.prefill.tokenspeed_mla."
|
||||
"TokenspeedMLAPrefillBackend"
|
||||
)
|
||||
|
||||
def get_path(self) -> str:
|
||||
"""Get the fully qualified class path for this backend."""
|
||||
|
||||
@@ -67,6 +67,7 @@ def _get_mla_prefill_backend_priorities(
|
||||
MLAPrefillBackendEnum.FLASH_ATTN,
|
||||
MLAPrefillBackendEnum.TRTLLM_RAGGED,
|
||||
MLAPrefillBackendEnum.FLASHINFER,
|
||||
MLAPrefillBackendEnum.TOKENSPEED_MLA,
|
||||
]
|
||||
else: # Hopper (SM90) and older
|
||||
return [
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""TokenSpeed CuTe DSL backend for MLA prefill."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.attention.backends.mla.prefill.base import MLAPrefillBackend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonPrefillMetadata,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
|
||||
|
||||
class TokenspeedMLAPrefillBackend(MLAPrefillBackend):
|
||||
"""TokenSpeed CuTe DSL backend for MLA prefill."""
|
||||
|
||||
requires_r1_mla_dimensions = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TOKENSPEED_MLA"
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, device_capability: "DeviceCapability") -> bool:
|
||||
return device_capability.major == 10
|
||||
|
||||
_INSTALL_HINT = (
|
||||
"tokenspeed_mla package is not installed. "
|
||||
"Install it with: `uv pip install tokenspeed-mla`"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_available(cls) -> bool:
|
||||
try:
|
||||
from tokenspeed_mla import (
|
||||
tokenspeed_mla_prefill, # noqa: F401
|
||||
)
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def validate_configuration(
|
||||
cls,
|
||||
device_capability,
|
||||
selector_config,
|
||||
) -> list[str]:
|
||||
# Replace the generic "required dependencies not available" message
|
||||
# from the base class with a specific install hint so users know
|
||||
# exactly which package to install when they explicitly select this
|
||||
# backend without having tokenspeed_mla installed.
|
||||
reasons = super().validate_configuration(device_capability, selector_config)
|
||||
return [
|
||||
cls._INSTALL_HINT if r == "required dependencies not available" else r
|
||||
for r in reasons
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
scale: float,
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
vllm_config: "VllmConfig",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads=num_heads,
|
||||
scale=scale,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
vllm_config=vllm_config,
|
||||
)
|
||||
|
||||
# Pre-JIT BF16 and FP8 prefill kernels. Idempotent — also called from
|
||||
# TokenspeedMLAImpl.__init__; second call is a no-op.
|
||||
from tokenspeed_mla import warmup_compile_prefill
|
||||
|
||||
for q_dtype in (torch.bfloat16, torch.float8_e4m3fn):
|
||||
warmup_compile_prefill(
|
||||
q_dtype=q_dtype,
|
||||
d_qk=qk_nope_head_dim + qk_rope_head_dim,
|
||||
d_v=v_head_dim,
|
||||
enable_pdl=False,
|
||||
)
|
||||
|
||||
def prepare_metadata(
|
||||
self,
|
||||
prefill_metadata: "MLACommonPrefillMetadata",
|
||||
) -> None:
|
||||
super().prepare_metadata(prefill_metadata)
|
||||
# Kernel signature requires `seq_lens` but the implementation never reads
|
||||
# it (per-batch lengths are derived from `cum_seq_lens` diffs); compute
|
||||
# for parity with trtllm_ragged. cuda-graph padding in
|
||||
# `query_start_loc` is saturated to `total_num_tokens`
|
||||
# (gpu_model_runner.py:1905), so trailing diffs are 0 and padded batches
|
||||
# are kernel no-ops — same reason trtllm passes the padded length as
|
||||
# batch_size directly.
|
||||
self._query_seq_lens = (
|
||||
prefill_metadata.query_start_loc[1:] - prefill_metadata.query_start_loc[:-1]
|
||||
)
|
||||
|
||||
def run_prefill_new_tokens(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
return_softmax_lse: bool,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from tokenspeed_mla import tokenspeed_mla_prefill
|
||||
|
||||
# `v` arrives as the second half of `kv_nope.split(...)` in
|
||||
# mla_attention.forward_mha — a non-contiguous view of `kv_nope` along
|
||||
# dim=-1. The kernel does `v.reshape(1, total_kv, h_k, 1, d_v)` which
|
||||
# would silently copy on a non-contiguous tensor; force contiguity here
|
||||
# so the copy (if any) happens once outside the kernel call.
|
||||
v = v.contiguous()
|
||||
|
||||
ret = tokenspeed_mla_prefill(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
seq_lens=self._query_seq_lens,
|
||||
cum_seq_lens=self._prefill_metadata.query_start_loc,
|
||||
max_seq_len=self._prefill_metadata.max_query_len,
|
||||
batch_size=self._query_seq_lens.shape[0],
|
||||
softmax_scale=self.scale,
|
||||
is_causal=True,
|
||||
return_lse=return_softmax_lse,
|
||||
enable_pdl=False,
|
||||
)
|
||||
|
||||
if isinstance(ret, tuple):
|
||||
# Convert from (q_len, num_heads) to (num_heads, q_len)
|
||||
return ret[0], ret[1].transpose(0, 1).contiguous()
|
||||
return ret
|
||||
|
||||
def run_prefill_context_chunk(
|
||||
self,
|
||||
chunk_idx: int,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from tokenspeed_mla import tokenspeed_mla_prefill
|
||||
|
||||
assert self._prefill_metadata.chunked_context is not None
|
||||
chunked = self._prefill_metadata.chunked_context
|
||||
|
||||
# See note in run_prefill_new_tokens — `v` is a split-view of `kv_nope`
|
||||
# in `_compute_prefill_context` and arrives non-contiguous.
|
||||
v = v.contiguous()
|
||||
|
||||
attn_out, lse = tokenspeed_mla_prefill(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
seq_lens=chunked.seq_lens[chunk_idx],
|
||||
cum_seq_lens=chunked.cu_seq_lens[chunk_idx],
|
||||
max_seq_len=chunked.max_seq_lens[chunk_idx],
|
||||
batch_size=chunked.seq_lens[chunk_idx].shape[0],
|
||||
softmax_scale=self.scale,
|
||||
is_causal=False,
|
||||
return_lse=True,
|
||||
cum_seq_lens_q=self._prefill_metadata.query_start_loc,
|
||||
max_seq_len_q=self._prefill_metadata.max_query_len,
|
||||
enable_pdl=False,
|
||||
)
|
||||
|
||||
# Convert from (q_len, num_heads) to (num_heads, q_len)
|
||||
return attn_out, lse.transpose(0, 1).contiguous()
|
||||
@@ -0,0 +1,277 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""TokenSpeed CuTe DSL MLA decode backend (Blackwell, FP8 KV cache only)."""
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Workspace upper bound for tokenspeed_mla_decode (per-device, lazy):
|
||||
# num_sms * num_heads * MAX_Q_LEN * (kv_lora_rank + 1) * sizeof(float32)
|
||||
# Matches the kernel's `get_workspace_size` formula. MAX_Q_LEN=8 covers up to
|
||||
# EAGLE3 / MTP-2 spec decoding query lengths; larger q_len fails the kernel's
|
||||
# own buffer check.
|
||||
_TOKENSPEED_MAX_Q_LEN = 8
|
||||
|
||||
_g_workspace: dict[torch.device, torch.Tensor] = {}
|
||||
|
||||
|
||||
def _get_workspace(
|
||||
device: torch.device, num_heads: int, kv_lora_rank: int
|
||||
) -> torch.Tensor:
|
||||
from tokenspeed_mla import get_num_sm
|
||||
|
||||
needed = (
|
||||
get_num_sm(device) * num_heads * _TOKENSPEED_MAX_Q_LEN * (kv_lora_rank + 1) * 4
|
||||
)
|
||||
existing = _g_workspace.get(device)
|
||||
if existing is None or existing.numel() < needed:
|
||||
_g_workspace[device] = torch.empty(needed, dtype=torch.int8, device=device)
|
||||
return _g_workspace[device]
|
||||
|
||||
|
||||
class TokenspeedMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||
|
||||
|
||||
class TokenspeedMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [32, 64]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TOKENSPEED_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TokenspeedMLAImpl"]:
|
||||
return TokenspeedMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["TokenspeedMLAMetadataBuilder"]:
|
||||
return TokenspeedMLAMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major == 10
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int | None,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
# Surface a clear install hint up front rather than letting a raw
|
||||
# ModuleNotFoundError fire deep inside `forward_mqa` at first request.
|
||||
try:
|
||||
import tokenspeed_mla # noqa: F401
|
||||
except ImportError:
|
||||
return (
|
||||
"tokenspeed_mla package is not installed. "
|
||||
"Install it with: `uv pip install tokenspeed-mla`"
|
||||
)
|
||||
|
||||
# tokenspeed_mla CuTe DSL kernel is shape-specialized for DeepSeek R1
|
||||
# MLA dimensions (qk_nope=128, qk_rope=64, v=128). Reject anything else.
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
if vllm_config.model_config is not None:
|
||||
hf_text_config = vllm_config.model_config.hf_text_config
|
||||
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 0)
|
||||
qk_rope_head_dim = getattr(hf_text_config, "qk_rope_head_dim", 0)
|
||||
v_head_dim = getattr(hf_text_config, "v_head_dim", 0)
|
||||
if qk_nope_head_dim != 128 or qk_rope_head_dim != 64 or v_head_dim != 128:
|
||||
return (
|
||||
"tokenspeed_mla requires DeepSeek R1 MLA dimensions "
|
||||
"(qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128), "
|
||||
f"got ({qk_nope_head_dim}, {qk_rope_head_dim}, {v_head_dim})"
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
|
||||
return "HND"
|
||||
|
||||
|
||||
class TokenspeedMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"TokenspeedMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"TokenspeedMLAImpl"
|
||||
)
|
||||
|
||||
if not is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"TokenspeedMLAImpl requires an FP8 KV cache "
|
||||
"(--kv-cache-dtype fp8 or fp8_e4m3); "
|
||||
f"got kv_cache_dtype={self.kv_cache_dtype!r}."
|
||||
)
|
||||
|
||||
# Allocate (or fetch the cached) workspace lazily on first forward —
|
||||
# __init__ runs before the device is necessarily set on the worker;
|
||||
# we know it for sure at forward time when we see the input tensor.
|
||||
self._workspace_buffer: torch.Tensor | None = None
|
||||
self.softmax_scale: float | None = None
|
||||
self.output_scale: float | None = None
|
||||
|
||||
# Pre-JIT BF16 and FP8 prefill kernels here too — decode impl always
|
||||
# runs when tokenspeed is selected, prefill backend may not (user can
|
||||
# pair with flash_attn / trtllm). Idempotent.
|
||||
from tokenspeed_mla import warmup_compile_prefill
|
||||
|
||||
for q_dtype in (torch.bfloat16, torch.float8_e4m3fn):
|
||||
warmup_compile_prefill(
|
||||
q_dtype=q_dtype,
|
||||
d_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
|
||||
d_v=self.v_head_dim,
|
||||
enable_pdl=False,
|
||||
)
|
||||
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
from tokenspeed_mla import tokenspeed_mla_decode
|
||||
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if isinstance(q, tuple):
|
||||
q_nope, q_pe = q
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
|
||||
# supports_quant_query_input=True (set in MLACommonImpl) tells the
|
||||
# pipeline to concat+FP8-quantize Q upstream via _decode_concat_quant_fp8_op.
|
||||
# The kernel is shape-specialized for FP8 Q + FP8 KV, so anything else
|
||||
# here means the upstream quant didn't run and the kernel will produce
|
||||
# garbage.
|
||||
assert q.dtype == torch.float8_e4m3fn, (
|
||||
f"TokenspeedMLAImpl expected FP8 query (supports_quant_query_input=True), "
|
||||
f"got {q.dtype}. Pipeline isinstance(q, tuple)={isinstance(q, tuple)}, "
|
||||
f"q_scale={layer._q_scale_float}, k_scale={layer._k_scale_float}."
|
||||
)
|
||||
|
||||
# tokenspeed_mla_decode expects query shape
|
||||
# (num_decodes, q_len_per_request, num_heads, head_dim).
|
||||
if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0:
|
||||
logger.warning_once(
|
||||
"""TokenspeedMLAImpl got a query of uneven length.
|
||||
This usually indicates an issue in batch reordering
|
||||
or incorrect setup in dummy_run."""
|
||||
)
|
||||
q = q.unsqueeze(1)
|
||||
else:
|
||||
q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])
|
||||
|
||||
if self.softmax_scale is None:
|
||||
# FP8 KV cache is mandatory for this backend, so q_scale/k_scale
|
||||
# always apply. softmax_scale is bmm1; output_scale is bmm2 — both
|
||||
# required to recover the correct attention output from the FP8
|
||||
# KV cache (V is stored as V_real/k_scale).
|
||||
self.softmax_scale = (
|
||||
self.scale * layer._q_scale_float * layer._k_scale_float
|
||||
)
|
||||
self.output_scale = layer._k_scale_float
|
||||
|
||||
if self._workspace_buffer is None:
|
||||
self._workspace_buffer = _get_workspace(
|
||||
q.device, self.num_heads, self.kv_lora_rank
|
||||
)
|
||||
|
||||
# vLLM kv_c_and_k_pe_cache is already (num_blocks, block_size, head_size).
|
||||
# tokenspeed_mla_decode wants 3D — pass as-is (no unsqueeze, unlike trtllm).
|
||||
o = tokenspeed_mla_decode(
|
||||
query=q,
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
seq_lens=attn_metadata.decode.seq_lens,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
softmax_scale=self.softmax_scale,
|
||||
output_scale=self.output_scale,
|
||||
enable_pdl=False,
|
||||
)
|
||||
|
||||
# Flatten the output for consistent shape
|
||||
o = o.view(-1, o.shape[-2], o.shape[-1])
|
||||
|
||||
# tokenspeed_mla_decode does not return LSE.
|
||||
return o, None
|
||||
@@ -63,6 +63,9 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
|
||||
FLASHINFER_MLA = (
|
||||
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
|
||||
)
|
||||
TOKENSPEED_MLA = (
|
||||
"vllm.v1.attention.backends.mla.tokenspeed_mla.TokenspeedMLABackend"
|
||||
)
|
||||
FLASHINFER_MLA_SPARSE = (
|
||||
"vllm.v1.attention.backends.mla.flashinfer_mla_sparse."
|
||||
"FlashInferMLASparseBackend"
|
||||
|
||||
Reference in New Issue
Block a user