Mirror of https://github.com/vllm-project/vllm.git (synced 2026-06-06 00:16:14 +00:00)
[DSv4] Drop _get_compressed_kv_buffer in DeepseekCompressor (#43690)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
(cherry picked from commit 193ce8812e)
This commit is contained in:
@@ -173,33 +173,6 @@ class CompressorStateCache(torch.nn.Module, AttentionLayerBase):
|
||||
|
||||
|
||||
class DeepseekCompressor(nn.Module):
|
||||
- _compressed_kv_buffers: ClassVar[dict[tuple[str, int, int], torch.Tensor]] = {}
|
||||
|
||||
- @classmethod
|
||||
- def _get_compressed_kv_buffer(
|
||||
- cls,
|
||||
- device: str,
|
||||
- max_num_tokens: int,
|
||||
- head_dim: int,
|
||||
- ) -> torch.Tensor:
|
||||
- if device == "cuda" and torch.accelerator.is_available():
|
||||
- device_key = f"cuda:{torch.accelerator.current_device_index()}"
|
||||
- alloc_device = torch.device(device_key)
|
||||
- else:
|
||||
- device_key = str(device)
|
||||
- alloc_device = torch.device(device)
|
||||
|
||||
|
||||
- key = (device_key, max_num_tokens, head_dim)
|
||||
- buffer = cls._compressed_kv_buffers.get(key)
|
||||
- if buffer is None:
|
||||
- buffer = torch.empty(
|
||||
- (max_num_tokens, head_dim),
|
||||
- dtype=torch.float32,
|
||||
- device=alloc_device,
|
||||
- )
|
||||
- cls._compressed_kv_buffers[key] = buffer
|
||||
- return buffer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
@@ -276,11 +249,6 @@ class DeepseekCompressor(nn.Module):
|
||||
self._fused_sparse_kernel = (
|
||||
_fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl
|
||||
)
|
||||
- self._compressed_kv_buffer = self._get_compressed_kv_buffer(
|
||||
- self.device,
|
||||
- vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
- self.head_dim,
|
||||
- )
|
||||
self._quant_block = 64
|
||||
self._token_stride = self.nope_head_dim + self.rope_head_dim * 2
|
||||
self._scale_dim = self.nope_head_dim // 64 + 1 # 7 real + 1 pad
|
||||
@@ -407,7 +375,11 @@ class DeepseekCompressor(nn.Module):
|
||||
overlap=self.overlap,
|
||||
)
|
||||
else:
|
||||
- compressed_kv = self._compressed_kv_buffer[:num_actual]
|
||||
+ compressed_kv = torch.empty(
|
||||
+ (num_actual, self.head_dim),
|
||||
+ dtype=torch.float32,
|
||||
+ device=state_cache.device,
|
||||
+ )
|
||||
self._compress_kernel(
|
||||
state_cache,
|
||||
token_to_req_indices,
|
||||
|
||||
Reference in New Issue
Block a user