[Perf] Support FP8 KV cache for Flashinfer MLA Sparse (#35891)

This commit is contained in:
Wei Zhao
2026-03-07 16:51:54 -05:00
committed by GitHub
parent a6be75dbd2
commit 379689d533
8 changed files with 89 additions and 17 deletions
+1 -1
View File
@@ -206,7 +206,7 @@ configuration.
|---------|--------|-----------|-------------|------------|------|--------|-----------|-----|-----------------|--------------|
| `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x |
| `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
| `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
+18 -2
View File
@@ -327,6 +327,12 @@ class MockSparseMLAAttentionLayer:
self._k_scale_float = 1.0
self._v_scale_float = 1.0
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
compile_native=True,
)
def forward_impl(
self,
q: torch.Tensor,
@@ -338,6 +344,7 @@ class MockSparseMLAAttentionLayer:
) -> torch.Tensor:
"""Forward for sparse MLA - uses forward_mqa for all tokens."""
kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
fp8_attention = kv_cache_dtype.startswith("fp8")
# Write to KV cache
if kv_cache.numel() > 0:
@@ -350,6 +357,9 @@ class MockSparseMLAAttentionLayer:
scale=self._k_scale,
)
if fp8_attention and kv_cache_dtype != "fp8_ds_mla":
kv_cache = kv_cache.view(current_platform.fp8_dtype())
num_tokens = q.shape[0]
# Sparse MLA uses forward_mqa for all tokens
@@ -367,8 +377,14 @@ class MockSparseMLAAttentionLayer:
# Convert from (N, B, L) to (B, N, L)
mqa_ql_nope = mqa_ql_nope.transpose(0, 1)
# Pass as tuple to forward_mqa
mqa_q = (mqa_ql_nope, mqa_q_pe)
if fp8_attention and self.impl.supports_quant_query_input:
assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
mqa_q = self._decode_concat_quant_fp8_op(
mqa_ql_nope, mqa_q_pe, self._q_scale
)
else:
mqa_q = (mqa_ql_nope, mqa_q_pe)
attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)
+11 -1
View File
@@ -191,6 +191,16 @@ def test_sparse_backend_decode_correctness(
if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes:
pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}")
if (
backend_cls == FlashMLASparseBackend
and kv_cache_dtype.startswith("fp8")
and kv_cache_dtype != "fp8_ds_mla"
):
pytest.skip(
"FlashMLA Sparse Attention backend fp8 only supports "
"fp8_ds_mla kv-cache dtype"
)
supported_block_sizes = backend_cls.get_supported_kernel_block_sizes()
if block_size not in supported_block_sizes:
pytest.skip(
@@ -419,7 +429,7 @@ def test_sparse_backend_decode_correctness(
num_blocks=vllm_config.cache_config.num_gpu_blocks,
common_attn_metadata=common_attn_metadata,
randomize_blocks=False,
kv_cache_dtype=kv_cache_dtype if use_fp8_ds_mla_quantization else "auto",
kv_cache_dtype=kv_cache_dtype,
scale=kv_cache_scale,
)
@@ -49,6 +49,11 @@ MLA_ATTENTION_FILE = (
# Backends to skip during doc generation
SKIP_BACKENDS = {"CUSTOM", "TORCH_SDPA"}
# Per-backend KV cache dtype names to leave out of the generated docs.
# analyze_backend() removes these names from the backend's parsed
# kv_cache_dtypes string before reporting it.
BACKEND_KV_DTYPE_EXCLUDES: dict[str, set[str]] = {
# fp8 is an alias for fp8_ds_mla for FlashMLA Sparse
"FLASHMLA_SPARSE": {"fp8"},
}
def is_relevant_file(filepath: str) -> bool:
"""Check if a file matches any of the relevant patterns."""
@@ -546,10 +551,19 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
tree, impl_class_name, "can_return_lse_for_decode", False, file_path
)
kv_cache_dtypes = parse_kv_cache_dtypes(class_node)
if backend_name in BACKEND_KV_DTYPE_EXCLUDES:
excluded = BACKEND_KV_DTYPE_EXCLUDES[backend_name]
kv_cache_dtypes = ", ".join(
d
for d in (d.strip() for d in kv_cache_dtypes.split(","))
if d not in excluded
)
return {
"name": backend_name,
"dtypes": parse_supported_dtypes(class_node),
"kv_cache_dtypes": parse_kv_cache_dtypes(class_node),
"kv_cache_dtypes": kv_cache_dtypes,
"block_sizes": parse_block_sizes(class_node),
"head_sizes": parse_head_sizes(class_node),
"attn_types": parse_attention_types(class_node),
@@ -331,11 +331,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
calculate_kv_scales = False
self.quant_config = quant_config
# Initialize KV cache quantization attributes
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
_init_kv_cache_quant(self, quant_config, prefix)
dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(
self.head_size,
@@ -347,6 +342,36 @@ class MLAAttention(nn.Module, AttentionLayerBase):
num_heads=self.num_heads,
)
# The FlashMLA Sparse attention backend only supports the "fp8_ds_mla" fp8
# kv-cache format, so any other fp8 kv-cache dtype is automatically
# converted to "fp8_ds_mla" here.
if (
self.attn_backend.get_name() == "FLASHMLA_SPARSE"
and kv_cache_dtype.startswith("fp8")
and kv_cache_dtype != "fp8_ds_mla"
):
assert cache_config is not None
cache_config.cache_dtype = "fp8_ds_mla"
kv_cache_dtype = "fp8_ds_mla"
logger.info_once(
"Using DeepSeek's fp8_ds_mla KV cache format. To use standard "
"fp8 kv-cache format, please set `--attention-backend "
"FLASHINFER_MLA_SPARSE`"
)
if (
self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE"
and kv_cache_dtype.startswith("fp8")
):
logger.info_once(
"Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla "
"KV cache format, please set `--attention-backend FLASHMLA_SPARSE`"
)
# Initialize KV cache quantization attributes
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
_init_kv_cache_quant(self, quant_config, prefix)
if (
cache_config is not None
and cache_config.enable_prefix_caching
-7
View File
@@ -31,20 +31,13 @@ class VerifyAndUpdateConfig:
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""
Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32
"""
hf_config = vllm_config.model_config.hf_config
# Mirror the check in vllm/model_executor/models/deepseek_v2.py
is_v32 = hasattr(hf_config, "index_topk")
assert is_v32
# For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled.
cache_config = vllm_config.cache_config
if cache_config.cache_dtype.startswith("fp8"):
cache_config.cache_dtype = "fp8_ds_mla"
logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
if cache_config.cache_dtype == "bfloat16":
cache_config.cache_dtype = "auto"
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
@@ -63,6 +63,8 @@ class FlashInferMLASparseBackend(AttentionBackend):
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
]
@staticmethod
@@ -304,6 +306,11 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None
# fp8 query quantization is required when using fp8 kv_cache,
# as the TRTLLM-GEN sparse MLA kernel requires matching dtypes
# for query and kv_cache (mixed bf16+fp8 is not supported).
self.supports_quant_query_input = True
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
@@ -83,6 +83,7 @@ class FlashMLASparseBackend(AttentionBackend):
"auto",
"bfloat16",
"fp8_ds_mla",
"fp8", # alias for fp8_ds_mla
]
@staticmethod
@@ -567,6 +568,12 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
)
self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)
if kv_cache_dtype.startswith("fp8"):
assert kv_cache_dtype == "fp8_ds_mla", (
"FlashMLA Sparse Attention backend fp8 only supports "
"fp8_ds_mla kv-cache dtype"
)
if kv_cache_dtype == "fp8_ds_mla":
# Reserve workspace during initialization
vllm_config = get_current_vllm_config()