mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Add CuTe DSL sparse compressor support (#43584)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: OpenAI Codex <codex@openai.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
@@ -11,6 +11,12 @@ Three specialized kernels:
|
||||
- _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn:
|
||||
head=128, MXFP4 (block=32), 4 ue8m0 bytes
|
||||
|
||||
Additional CuTe DSL kernels:
|
||||
- _compress_kv_sparse_attn_cutedsl / _norm_rope_insert_sparse_attn_cutedsl:
|
||||
CuTe DSL split kernels for C128
|
||||
- _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl:
|
||||
CuTe DSL fused kernels for C4 (compress_ratio == 4)
|
||||
|
||||
RoPE is register-based via tl.reshape -> tl.split -> tl.interleave (or the
|
||||
even/odd halves are consumed directly for MXFP4, no interleave needed).
|
||||
FP8 UE8M0 quant uses tl.reshape to tile [N_QUANT_BLOCKS, QUANT_BLOCK] for
|
||||
@@ -19,11 +25,43 @@ even/odd halves, producing (N_QUANT_BLOCKS, MXFP4_BLOCK/2) packed nibbles
|
||||
and N_QUANT_BLOCKS ue8m0 bytes.
|
||||
"""
|
||||
|
||||
from functools import cache
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .fused_indexer_q import _fp32x2_to_fp4x2
|
||||
|
||||
|
||||
@cache
def _get_sparse_attn_cutedsl_impls():
    """Import the CuTe DSL sparse-attention entry points on first use.

    The import lives inside the function body, so the
    ``sparse_attn_compress_cutedsl`` module is only loaded when this function
    is first called; ``functools.cache`` then memoizes the returned tuple.

    Returns:
        A 3-tuple ordered as ``(compress, norm_rope_insert, fused)``. The
        module-level wrappers of the same names select their implementation
        from this tuple by position.
    """
    from .sparse_attn_compress_cutedsl import (
        _compress_kv_sparse_attn_cutedsl as compress_impl,
        _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl as fused_impl,
        _norm_rope_insert_sparse_attn_cutedsl as norm_rope_insert_impl,
    )

    # Order matters: callers index this tuple as [0], [1], [2].
    impls = (compress_impl, norm_rope_insert_impl, fused_impl)
    return impls
|
||||
|
||||
|
||||
def _compress_kv_sparse_attn_cutedsl(*args, **kwargs):
    """Forward all arguments to the lazily imported CuTe DSL compress impl."""
    # First element of the cached implementation tuple is the compress entry.
    compress_impl, _, _ = _get_sparse_attn_cutedsl_impls()
    return compress_impl(*args, **kwargs)
|
||||
|
||||
|
||||
def _norm_rope_insert_sparse_attn_cutedsl(*args, **kwargs):
    """Forward all arguments to the CuTe DSL RMSNorm/RoPE/FP8-store impl."""
    # Second element of the cached implementation tuple.
    _, norm_rope_insert_impl, _ = _get_sparse_attn_cutedsl_impls()
    return norm_rope_insert_impl(*args, **kwargs)
|
||||
|
||||
|
||||
def _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl(*args, **kwargs):
    """Forward all arguments to the CuTe DSL fused C4 compressor impl."""
    # Third element of the cached implementation tuple.
    _, _, fused_impl = _get_sparse_attn_cutedsl_impls()
    return fused_impl(*args, **kwargs)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DeepseekV4 Attention path (head=512, nope=448 FP8 + rope=64 bf16)
|
||||
# =============================================================================
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,9 +13,11 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
|
||||
from vllm.models.deepseek_v4.common.ops.fused_compress_quant_cache import (
|
||||
_compress_kv_sparse_attn_cutedsl,
|
||||
_fused_kv_compress_norm_rope_insert_indexer_attn,
|
||||
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn,
|
||||
_fused_kv_compress_norm_rope_insert_sparse_attn,
|
||||
_fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
|
||||
_norm_rope_insert_sparse_attn_cutedsl,
|
||||
)
|
||||
from vllm.models.deepseek_v4.common.ops.fused_indexer_q import MXFP4_BLOCK_SIZE
|
||||
from vllm.platforms import current_platform
|
||||
@@ -171,6 +173,33 @@ class CompressorStateCache(torch.nn.Module, AttentionLayerBase):
|
||||
|
||||
|
||||
class DeepseekCompressor(nn.Module):
|
||||
_compressed_kv_buffers: ClassVar[dict[tuple[str, int, int], torch.Tensor]] = {}
|
||||
|
||||
@classmethod
|
||||
def _get_compressed_kv_buffer(
|
||||
cls,
|
||||
device: str,
|
||||
max_num_tokens: int,
|
||||
head_dim: int,
|
||||
) -> torch.Tensor:
|
||||
if device == "cuda" and torch.accelerator.is_available():
|
||||
device_key = f"cuda:{torch.accelerator.current_device_index()}"
|
||||
alloc_device = torch.device(device_key)
|
||||
else:
|
||||
device_key = str(device)
|
||||
alloc_device = torch.device(device)
|
||||
|
||||
key = (device_key, max_num_tokens, head_dim)
|
||||
buffer = cls._compressed_kv_buffers.get(key)
|
||||
if buffer is None:
|
||||
buffer = torch.empty(
|
||||
(max_num_tokens, head_dim),
|
||||
dtype=torch.float32,
|
||||
device=alloc_device,
|
||||
)
|
||||
cls._compressed_kv_buffers[key] = buffer
|
||||
return buffer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
@@ -240,12 +269,24 @@ class DeepseekCompressor(nn.Module):
|
||||
assert not use_fp4_cache, (
|
||||
"MXFP4 cache is only supported for indexer (head=128)"
|
||||
)
|
||||
self._fused_kernel = _fused_kv_compress_norm_rope_insert_sparse_attn
|
||||
self._use_cutedsl_sparse_compressor = True
|
||||
self._use_cutedsl_fused_sparse_compressor = self.compress_ratio == 4
|
||||
self._compress_kernel = _compress_kv_sparse_attn_cutedsl
|
||||
self._norm_rope_store_kernel = _norm_rope_insert_sparse_attn_cutedsl
|
||||
self._fused_sparse_kernel = (
|
||||
_fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl
|
||||
)
|
||||
self._compressed_kv_buffer = self._get_compressed_kv_buffer(
|
||||
self.device,
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
self.head_dim,
|
||||
)
|
||||
self._quant_block = 64
|
||||
self._token_stride = self.nope_head_dim + self.rope_head_dim * 2
|
||||
self._scale_dim = self.nope_head_dim // 64 + 1 # 7 real + 1 pad
|
||||
self._num_warps = 4
|
||||
elif self.head_dim == 128:
|
||||
self._use_cutedsl_sparse_compressor = False
|
||||
if use_fp4_cache:
|
||||
self._fused_kernel = (
|
||||
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn
|
||||
@@ -339,43 +380,104 @@ class DeepseekCompressor(nn.Module):
|
||||
k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix])
|
||||
kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache
|
||||
|
||||
self._fused_kernel[(num_actual,)](
|
||||
# state cache
|
||||
state_cache,
|
||||
state_cache.stride(0),
|
||||
state_cache.stride(1),
|
||||
# metadata
|
||||
token_to_req_indices,
|
||||
positions,
|
||||
slot_mapping,
|
||||
block_table,
|
||||
block_table.stride(0),
|
||||
block_size,
|
||||
# RMSNorm
|
||||
self.norm.weight,
|
||||
self.rms_norm_eps,
|
||||
# RoPE
|
||||
cos_sin_cache,
|
||||
cos_sin_cache.stride(0),
|
||||
# KV cache
|
||||
kv_cache,
|
||||
k_cache_metadata.slot_mapping,
|
||||
kv_cache.shape[1], # paged KV cache block size (tokens per block)
|
||||
# constexprs
|
||||
HEAD_SIZE=self.head_dim,
|
||||
TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim),
|
||||
STATE_WIDTH=state_width,
|
||||
COMPRESS_RATIO=self.compress_ratio,
|
||||
OVERLAP=self.overlap,
|
||||
ROPE_HEAD_DIM=self.rope_head_dim,
|
||||
FP8_MAX=448.0,
|
||||
QUANT_BLOCK=self._quant_block,
|
||||
TOKEN_STRIDE=self._token_stride,
|
||||
SCALE_DIM=self._scale_dim,
|
||||
KV_BLOCK_STRIDE=kv_cache.stride(0),
|
||||
num_warps=self._num_warps,
|
||||
**pdl_kwargs,
|
||||
)
|
||||
if self._use_cutedsl_sparse_compressor:
|
||||
if self._use_cutedsl_fused_sparse_compressor:
|
||||
self._fused_sparse_kernel(
|
||||
state_cache,
|
||||
token_to_req_indices,
|
||||
positions,
|
||||
slot_mapping,
|
||||
block_table,
|
||||
block_size,
|
||||
self.norm.weight,
|
||||
self.rms_norm_eps,
|
||||
cos_sin_cache,
|
||||
kv_cache,
|
||||
k_cache_metadata.slot_mapping,
|
||||
kv_cache.shape[1], # paged KV cache block size
|
||||
kv_cache.stride(0),
|
||||
head_size=self.head_dim,
|
||||
state_width=state_width,
|
||||
rope_head_dim=self.rope_head_dim,
|
||||
fp8_max=448.0,
|
||||
quant_block=self._quant_block,
|
||||
token_stride=self._token_stride,
|
||||
scale_dim=self._scale_dim,
|
||||
compress_ratio=self.compress_ratio,
|
||||
overlap=self.overlap,
|
||||
)
|
||||
else:
|
||||
compressed_kv = self._compressed_kv_buffer[:num_actual]
|
||||
self._compress_kernel(
|
||||
state_cache,
|
||||
token_to_req_indices,
|
||||
positions,
|
||||
slot_mapping,
|
||||
block_table,
|
||||
block_size,
|
||||
compressed_kv,
|
||||
head_size=self.head_dim,
|
||||
state_width=state_width,
|
||||
compress_ratio=self.compress_ratio,
|
||||
overlap=self.overlap,
|
||||
)
|
||||
self._norm_rope_store_kernel(
|
||||
compressed_kv,
|
||||
positions,
|
||||
slot_mapping,
|
||||
self.norm.weight,
|
||||
self.rms_norm_eps,
|
||||
cos_sin_cache,
|
||||
kv_cache,
|
||||
k_cache_metadata.slot_mapping,
|
||||
kv_cache.shape[1], # paged KV cache block size
|
||||
kv_cache.stride(0),
|
||||
head_size=self.head_dim,
|
||||
rope_head_dim=self.rope_head_dim,
|
||||
fp8_max=448.0,
|
||||
quant_block=self._quant_block,
|
||||
token_stride=self._token_stride,
|
||||
scale_dim=self._scale_dim,
|
||||
compress_ratio=self.compress_ratio,
|
||||
)
|
||||
else:
|
||||
self._fused_kernel[(num_actual,)](
|
||||
# state cache
|
||||
state_cache,
|
||||
state_cache.stride(0),
|
||||
state_cache.stride(1),
|
||||
# metadata
|
||||
token_to_req_indices,
|
||||
positions,
|
||||
slot_mapping,
|
||||
block_table,
|
||||
block_table.stride(0),
|
||||
block_size,
|
||||
# RMSNorm
|
||||
self.norm.weight,
|
||||
self.rms_norm_eps,
|
||||
# RoPE
|
||||
cos_sin_cache,
|
||||
cos_sin_cache.stride(0),
|
||||
# KV cache
|
||||
kv_cache,
|
||||
k_cache_metadata.slot_mapping,
|
||||
kv_cache.shape[1], # paged KV cache block size (tokens per block)
|
||||
# constexprs
|
||||
HEAD_SIZE=self.head_dim,
|
||||
TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim),
|
||||
STATE_WIDTH=state_width,
|
||||
COMPRESS_RATIO=self.compress_ratio,
|
||||
OVERLAP=self.overlap,
|
||||
ROPE_HEAD_DIM=self.rope_head_dim,
|
||||
FP8_MAX=448.0,
|
||||
QUANT_BLOCK=self._quant_block,
|
||||
TOKEN_STRIDE=self._token_stride,
|
||||
SCALE_DIM=self._scale_dim,
|
||||
KV_BLOCK_STRIDE=kv_cache.stride(0),
|
||||
num_warps=self._num_warps,
|
||||
**pdl_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
||||
Reference in New Issue
Block a user