Add CuTe DSL sparse compressor support (#43584)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
Jie Fang
2026-05-26 15:11:12 +08:00
committed by GitHub
parent e6adbd7834
commit a37e47100c
3 changed files with 1419 additions and 39 deletions
@@ -11,6 +11,12 @@ Three specialized kernels:
- _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn:
head=128, MXFP4 (block=32), 4 ue8m0 bytes
Additional cutedsl kernels:
- _compress_kv_sparse_attn_cutedsl / _norm_rope_insert_sparse_attn_cutedsl:
CuTe DSL split kernels for C128
- _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl:
CuTe DSL fused kernels for C4
RoPE is register-based via tl.reshape -> tl.split -> tl.interleave (or the
even/odd halves are consumed directly for MXFP4, no interleave needed).
FP8 UE8M0 quant uses tl.reshape to tile [N_QUANT_BLOCKS, QUANT_BLOCK] for
@@ -19,11 +25,43 @@ even/odd halves, producing (N_QUANT_BLOCKS, MXFP4_BLOCK/2) packed nibbles
and N_QUANT_BLOCKS ue8m0 bytes.
"""
from functools import cache
from vllm.triton_utils import tl, triton
from .fused_indexer_q import _fp32x2_to_fp4x2
@cache
def _get_sparse_attn_cutedsl_impls():
from .sparse_attn_compress_cutedsl import (
_compress_kv_sparse_attn_cutedsl,
_fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
_norm_rope_insert_sparse_attn_cutedsl,
)
return (
_compress_kv_sparse_attn_cutedsl,
_norm_rope_insert_sparse_attn_cutedsl,
_fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
)
def _compress_kv_sparse_attn_cutedsl(*args, **kwargs):
"""CuTe DSL sparse-attention compress wrapper."""
return _get_sparse_attn_cutedsl_impls()[0](*args, **kwargs)
def _norm_rope_insert_sparse_attn_cutedsl(*args, **kwargs):
"""CuTe DSL RMSNorm/RoPE/FP8-store wrapper."""
return _get_sparse_attn_cutedsl_impls()[1](*args, **kwargs)
def _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl(*args, **kwargs):
"""CuTe DSL fused C4 sparse-attention compressor wrapper."""
return _get_sparse_attn_cutedsl_impls()[2](*args, **kwargs)
# =============================================================================
# DeepseekV4 Attention path (head=512, nope=448 FP8 + rope=64 bf16)
# =============================================================================
File diff suppressed because it is too large Load Diff
+141 -39
View File
@@ -13,9 +13,11 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
from vllm.models.deepseek_v4.common.ops.fused_compress_quant_cache import (
_compress_kv_sparse_attn_cutedsl,
_fused_kv_compress_norm_rope_insert_indexer_attn,
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn,
_fused_kv_compress_norm_rope_insert_sparse_attn,
_fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
_norm_rope_insert_sparse_attn_cutedsl,
)
from vllm.models.deepseek_v4.common.ops.fused_indexer_q import MXFP4_BLOCK_SIZE
from vllm.platforms import current_platform
@@ -171,6 +173,33 @@ class CompressorStateCache(torch.nn.Module, AttentionLayerBase):
class DeepseekCompressor(nn.Module):
_compressed_kv_buffers: ClassVar[dict[tuple[str, int, int], torch.Tensor]] = {}
@classmethod
def _get_compressed_kv_buffer(
cls,
device: str,
max_num_tokens: int,
head_dim: int,
) -> torch.Tensor:
if device == "cuda" and torch.accelerator.is_available():
device_key = f"cuda:{torch.accelerator.current_device_index()}"
alloc_device = torch.device(device_key)
else:
device_key = str(device)
alloc_device = torch.device(device)
key = (device_key, max_num_tokens, head_dim)
buffer = cls._compressed_kv_buffers.get(key)
if buffer is None:
buffer = torch.empty(
(max_num_tokens, head_dim),
dtype=torch.float32,
device=alloc_device,
)
cls._compressed_kv_buffers[key] = buffer
return buffer
def __init__(
self,
vllm_config: VllmConfig,
@@ -240,12 +269,24 @@ class DeepseekCompressor(nn.Module):
assert not use_fp4_cache, (
"MXFP4 cache is only supported for indexer (head=128)"
)
self._fused_kernel = _fused_kv_compress_norm_rope_insert_sparse_attn
self._use_cutedsl_sparse_compressor = True
self._use_cutedsl_fused_sparse_compressor = self.compress_ratio == 4
self._compress_kernel = _compress_kv_sparse_attn_cutedsl
self._norm_rope_store_kernel = _norm_rope_insert_sparse_attn_cutedsl
self._fused_sparse_kernel = (
_fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl
)
self._compressed_kv_buffer = self._get_compressed_kv_buffer(
self.device,
vllm_config.scheduler_config.max_num_batched_tokens,
self.head_dim,
)
self._quant_block = 64
self._token_stride = self.nope_head_dim + self.rope_head_dim * 2
self._scale_dim = self.nope_head_dim // 64 + 1 # 7 real + 1 pad
self._num_warps = 4
elif self.head_dim == 128:
self._use_cutedsl_sparse_compressor = False
if use_fp4_cache:
self._fused_kernel = (
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn
@@ -339,43 +380,104 @@ class DeepseekCompressor(nn.Module):
k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix])
kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache
self._fused_kernel[(num_actual,)](
# state cache
state_cache,
state_cache.stride(0),
state_cache.stride(1),
# metadata
token_to_req_indices,
positions,
slot_mapping,
block_table,
block_table.stride(0),
block_size,
# RMSNorm
self.norm.weight,
self.rms_norm_eps,
# RoPE
cos_sin_cache,
cos_sin_cache.stride(0),
# KV cache
kv_cache,
k_cache_metadata.slot_mapping,
kv_cache.shape[1], # paged KV cache block size (tokens per block)
# constexprs
HEAD_SIZE=self.head_dim,
TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim),
STATE_WIDTH=state_width,
COMPRESS_RATIO=self.compress_ratio,
OVERLAP=self.overlap,
ROPE_HEAD_DIM=self.rope_head_dim,
FP8_MAX=448.0,
QUANT_BLOCK=self._quant_block,
TOKEN_STRIDE=self._token_stride,
SCALE_DIM=self._scale_dim,
KV_BLOCK_STRIDE=kv_cache.stride(0),
num_warps=self._num_warps,
**pdl_kwargs,
)
if self._use_cutedsl_sparse_compressor:
if self._use_cutedsl_fused_sparse_compressor:
self._fused_sparse_kernel(
state_cache,
token_to_req_indices,
positions,
slot_mapping,
block_table,
block_size,
self.norm.weight,
self.rms_norm_eps,
cos_sin_cache,
kv_cache,
k_cache_metadata.slot_mapping,
kv_cache.shape[1], # paged KV cache block size
kv_cache.stride(0),
head_size=self.head_dim,
state_width=state_width,
rope_head_dim=self.rope_head_dim,
fp8_max=448.0,
quant_block=self._quant_block,
token_stride=self._token_stride,
scale_dim=self._scale_dim,
compress_ratio=self.compress_ratio,
overlap=self.overlap,
)
else:
compressed_kv = self._compressed_kv_buffer[:num_actual]
self._compress_kernel(
state_cache,
token_to_req_indices,
positions,
slot_mapping,
block_table,
block_size,
compressed_kv,
head_size=self.head_dim,
state_width=state_width,
compress_ratio=self.compress_ratio,
overlap=self.overlap,
)
self._norm_rope_store_kernel(
compressed_kv,
positions,
slot_mapping,
self.norm.weight,
self.rms_norm_eps,
cos_sin_cache,
kv_cache,
k_cache_metadata.slot_mapping,
kv_cache.shape[1], # paged KV cache block size
kv_cache.stride(0),
head_size=self.head_dim,
rope_head_dim=self.rope_head_dim,
fp8_max=448.0,
quant_block=self._quant_block,
token_stride=self._token_stride,
scale_dim=self._scale_dim,
compress_ratio=self.compress_ratio,
)
else:
self._fused_kernel[(num_actual,)](
# state cache
state_cache,
state_cache.stride(0),
state_cache.stride(1),
# metadata
token_to_req_indices,
positions,
slot_mapping,
block_table,
block_table.stride(0),
block_size,
# RMSNorm
self.norm.weight,
self.rms_norm_eps,
# RoPE
cos_sin_cache,
cos_sin_cache.stride(0),
# KV cache
kv_cache,
k_cache_metadata.slot_mapping,
kv_cache.shape[1], # paged KV cache block size (tokens per block)
# constexprs
HEAD_SIZE=self.head_dim,
TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim),
STATE_WIDTH=state_width,
COMPRESS_RATIO=self.compress_ratio,
OVERLAP=self.overlap,
ROPE_HEAD_DIM=self.rope_head_dim,
FP8_MAX=448.0,
QUANT_BLOCK=self._quant_block,
TOKEN_STRIDE=self._token_stride,
SCALE_DIM=self._scale_dim,
KV_BLOCK_STRIDE=kv_cache.stride(0),
num_warps=self._num_warps,
**pdl_kwargs,
)
@triton.jit