mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Perf] Support FP8 KV cache for Flashinfer MLA Sparse (#35891)
This commit is contained in:
@@ -206,7 +206,7 @@ configuration.
|
||||
|---------|--------|-----------|-------------|------------|------|--------|-----------|-----|-----------------|--------------|
|
||||
| `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x |
|
||||
| `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
|
||||
| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
|
||||
| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
|
||||
| `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
|
||||
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
|
||||
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
|
||||
|
||||
@@ -327,6 +327,12 @@ class MockSparseMLAAttentionLayer:
|
||||
self._k_scale_float = 1.0
|
||||
self._v_scale_float = 1.0
|
||||
|
||||
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
compile_native=True,
|
||||
)
|
||||
|
||||
def forward_impl(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
@@ -338,6 +344,7 @@ class MockSparseMLAAttentionLayer:
|
||||
) -> torch.Tensor:
|
||||
"""Forward for sparse MLA - uses forward_mqa for all tokens."""
|
||||
kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
|
||||
fp8_attention = kv_cache_dtype.startswith("fp8")
|
||||
|
||||
# Write to KV cache
|
||||
if kv_cache.numel() > 0:
|
||||
@@ -350,6 +357,9 @@ class MockSparseMLAAttentionLayer:
|
||||
scale=self._k_scale,
|
||||
)
|
||||
|
||||
if fp8_attention and kv_cache_dtype != "fp8_ds_mla":
|
||||
kv_cache = kv_cache.view(current_platform.fp8_dtype())
|
||||
|
||||
num_tokens = q.shape[0]
|
||||
|
||||
# Sparse MLA uses forward_mqa for all tokens
|
||||
@@ -367,8 +377,14 @@ class MockSparseMLAAttentionLayer:
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
mqa_ql_nope = mqa_ql_nope.transpose(0, 1)
|
||||
|
||||
# Pass as tuple to forward_mqa
|
||||
mqa_q = (mqa_ql_nope, mqa_q_pe)
|
||||
if fp8_attention and self.impl.supports_quant_query_input:
|
||||
assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
|
||||
assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
|
||||
mqa_q = self._decode_concat_quant_fp8_op(
|
||||
mqa_ql_nope, mqa_q_pe, self._q_scale
|
||||
)
|
||||
else:
|
||||
mqa_q = (mqa_ql_nope, mqa_q_pe)
|
||||
|
||||
attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)
|
||||
|
||||
|
||||
@@ -191,6 +191,16 @@ def test_sparse_backend_decode_correctness(
|
||||
if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes:
|
||||
pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}")
|
||||
|
||||
if (
|
||||
backend_cls == FlashMLASparseBackend
|
||||
and kv_cache_dtype.startswith("fp8")
|
||||
and kv_cache_dtype != "fp8_ds_mla"
|
||||
):
|
||||
pytest.skip(
|
||||
"FlashMLA Sparse Attention backend fp8 only supports "
|
||||
"fp8_ds_mla kv-cache dtype"
|
||||
)
|
||||
|
||||
supported_block_sizes = backend_cls.get_supported_kernel_block_sizes()
|
||||
if block_size not in supported_block_sizes:
|
||||
pytest.skip(
|
||||
@@ -419,7 +429,7 @@ def test_sparse_backend_decode_correctness(
|
||||
num_blocks=vllm_config.cache_config.num_gpu_blocks,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
randomize_blocks=False,
|
||||
kv_cache_dtype=kv_cache_dtype if use_fp8_ds_mla_quantization else "auto",
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
scale=kv_cache_scale,
|
||||
)
|
||||
|
||||
|
||||
@@ -49,6 +49,11 @@ MLA_ATTENTION_FILE = (
|
||||
# Backends to skip during doc generation
|
||||
SKIP_BACKENDS = {"CUSTOM", "TORCH_SDPA"}
|
||||
|
||||
BACKEND_KV_DTYPE_EXCLUDES: dict[str, set[str]] = {
|
||||
# fp8 is an alias for fp8_ds_mla for FlashMLA Sparse
|
||||
"FLASHMLA_SPARSE": {"fp8"},
|
||||
}
|
||||
|
||||
|
||||
def is_relevant_file(filepath: str) -> bool:
|
||||
"""Check if a file matches any of the relevant patterns."""
|
||||
@@ -546,10 +551,19 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
|
||||
tree, impl_class_name, "can_return_lse_for_decode", False, file_path
|
||||
)
|
||||
|
||||
kv_cache_dtypes = parse_kv_cache_dtypes(class_node)
|
||||
if backend_name in BACKEND_KV_DTYPE_EXCLUDES:
|
||||
excluded = BACKEND_KV_DTYPE_EXCLUDES[backend_name]
|
||||
kv_cache_dtypes = ", ".join(
|
||||
d
|
||||
for d in (d.strip() for d in kv_cache_dtypes.split(","))
|
||||
if d not in excluded
|
||||
)
|
||||
|
||||
return {
|
||||
"name": backend_name,
|
||||
"dtypes": parse_supported_dtypes(class_node),
|
||||
"kv_cache_dtypes": parse_kv_cache_dtypes(class_node),
|
||||
"kv_cache_dtypes": kv_cache_dtypes,
|
||||
"block_sizes": parse_block_sizes(class_node),
|
||||
"head_sizes": parse_head_sizes(class_node),
|
||||
"attn_types": parse_attention_types(class_node),
|
||||
|
||||
@@ -331,11 +331,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
calculate_kv_scales = False
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Initialize KV cache quantization attributes
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.calculate_kv_scales = calculate_kv_scales
|
||||
_init_kv_cache_quant(self, quant_config, prefix)
|
||||
|
||||
dtype = torch.get_default_dtype()
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.head_size,
|
||||
@@ -347,6 +342,36 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
num_heads=self.num_heads,
|
||||
)
|
||||
|
||||
# FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format
|
||||
# Automatically convert fp8 kv-cache format to "fp8_ds_mla"
|
||||
if (
|
||||
self.attn_backend.get_name() == "FLASHMLA_SPARSE"
|
||||
and kv_cache_dtype.startswith("fp8")
|
||||
and kv_cache_dtype != "fp8_ds_mla"
|
||||
):
|
||||
assert cache_config is not None
|
||||
cache_config.cache_dtype = "fp8_ds_mla"
|
||||
kv_cache_dtype = "fp8_ds_mla"
|
||||
logger.info_once(
|
||||
"Using DeepSeek's fp8_ds_mla KV cache format. To use standard "
|
||||
"fp8 kv-cache format, please set `--attention-backend "
|
||||
"FLASHINFER_MLA_SPARSE`"
|
||||
)
|
||||
|
||||
if (
|
||||
self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE"
|
||||
and kv_cache_dtype.startswith("fp8")
|
||||
):
|
||||
logger.info_once(
|
||||
"Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla "
|
||||
"KV cache format, please set `--attention-backend FLASHMLA_SPARSE`"
|
||||
)
|
||||
|
||||
# Initialize KV cache quantization attributes
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.calculate_kv_scales = calculate_kv_scales
|
||||
_init_kv_cache_quant(self, quant_config, prefix)
|
||||
|
||||
if (
|
||||
cache_config is not None
|
||||
and cache_config.enable_prefix_caching
|
||||
|
||||
@@ -31,20 +31,13 @@ class VerifyAndUpdateConfig:
|
||||
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
|
||||
@classmethod
|
||||
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32
|
||||
"""
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
|
||||
# Mirror the check in vllm/model_executor/models/deepseek_v2.py
|
||||
is_v32 = hasattr(hf_config, "index_topk")
|
||||
assert is_v32
|
||||
|
||||
# For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled.
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config.cache_dtype.startswith("fp8"):
|
||||
cache_config.cache_dtype = "fp8_ds_mla"
|
||||
logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
|
||||
if cache_config.cache_dtype == "bfloat16":
|
||||
cache_config.cache_dtype = "auto"
|
||||
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
|
||||
|
||||
@@ -63,6 +63,8 @@ class FlashInferMLASparseBackend(AttentionBackend):
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -304,6 +306,11 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
|
||||
self.bmm1_scale: float | None = None
|
||||
self.bmm2_scale: float | None = None
|
||||
|
||||
# fp8 query quantization is required when using fp8 kv_cache,
|
||||
# as the TRTLLM-GEN sparse MLA kernel requires matching dtypes
|
||||
# for query and kv_cache (mixed bf16+fp8 is not supported).
|
||||
self.supports_quant_query_input = True
|
||||
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
|
||||
@@ -83,6 +83,7 @@ class FlashMLASparseBackend(AttentionBackend):
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8_ds_mla",
|
||||
"fp8", # alias for fp8_ds_mla
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -567,6 +568,12 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
)
|
||||
self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)
|
||||
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
assert kv_cache_dtype == "fp8_ds_mla", (
|
||||
"FlashMLA Sparse Attention backend fp8 only supports "
|
||||
"fp8_ds_mla kv-cache dtype"
|
||||
)
|
||||
|
||||
if kv_cache_dtype == "fp8_ds_mla":
|
||||
# Reserve workspace during initialization
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
Reference in New Issue
Block a user