mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Model Runner V2] Support zeroing freshly allocated KV blocks for hybrid + fp8 KVCache (#43990)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
@@ -47,6 +47,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
|
||||
@@ -103,6 +104,7 @@ from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
from vllm.v1.worker.utils import KVBlockZeroer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -129,6 +131,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.cache_config.cache_dtype
|
||||
]
|
||||
|
||||
# Lazily initialized in _init_kv_zero_meta() when the KV cache needs
|
||||
# zeroing (e.g. hybrid models with fp8 KV cache).
|
||||
self.kv_block_zeroer: KVBlockZeroer | None = None
|
||||
|
||||
self.vocab_size = self.model_config.get_vocab_size()
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
@@ -393,7 +399,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
) + spec.num_speculative_blocks
|
||||
max_num_blocks_per_group.append(max_num_blocks)
|
||||
|
||||
self.attn_groups, attn_cg_support, kernel_block_sizes = init_attn_backend(
|
||||
self.attn_groups, attn_cg_support, self.kernel_block_sizes = init_attn_backend(
|
||||
self.kv_cache_config, self.vllm_config, self.device
|
||||
)
|
||||
self.block_tables = BlockTables(
|
||||
@@ -402,7 +408,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
max_num_blocks_per_group=max_num_blocks_per_group,
|
||||
device=self.device,
|
||||
kernel_block_sizes=kernel_block_sizes,
|
||||
kernel_block_sizes=self.kernel_block_sizes,
|
||||
cp_size=self.dcp_size,
|
||||
cp_rank=self.dcp_rank,
|
||||
cp_interleave=self.cp_interleave,
|
||||
@@ -442,11 +448,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.attn_groups,
|
||||
self.device,
|
||||
self.cache_config.cache_dtype,
|
||||
kernel_block_sizes,
|
||||
self.kernel_block_sizes,
|
||||
self.vllm_config,
|
||||
)
|
||||
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
|
||||
|
||||
def _init_kv_zero_meta(self) -> None:
|
||||
"""Build KV-block zeroing metadata; invoked from gpu_worker."""
|
||||
self.kv_block_zeroer = KVBlockZeroer(
|
||||
self.device,
|
||||
is_pin_memory_available(),
|
||||
attn_groups_iter=(g for groups in self.attn_groups for g in groups),
|
||||
kernel_block_sizes=self.kernel_block_sizes,
|
||||
cache_dtype=self.cache_config.cache_dtype,
|
||||
static_forward_context=self.compilation_config.static_forward_context,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
@step_eplb_after(is_dummy=True)
|
||||
def _dummy_run(
|
||||
@@ -753,6 +770,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
out=self.req_states.num_computed_prefill_tokens,
|
||||
)
|
||||
|
||||
# Zero GPU memory for freshly allocated cache blocks to prevent
|
||||
# stale NaN/data from corrupting attention or SSM computation.
|
||||
if scheduler_output.new_block_ids_to_zero:
|
||||
assert self.kv_block_zeroer is not None
|
||||
self.kv_block_zeroer.zero_block_ids(scheduler_output.new_block_ids_to_zero)
|
||||
|
||||
def prepare_inputs(
|
||||
self, scheduler_output: SchedulerOutput, batch_desc: BatchExecutionDescriptor
|
||||
) -> InputBatch:
|
||||
|
||||
@@ -1084,11 +1084,11 @@ class GPUModelRunner(
|
||||
def _init_kv_zero_meta(self) -> None:
|
||||
"""One-time precomputation for _zero_block_ids.
|
||||
|
||||
Delegates to the KVBlockZeroer constructor with the runner's state.
|
||||
Called from gpu_worker.py outside the CuMem pool context.
|
||||
"""
|
||||
self._kv_block_zeroer = KVBlockZeroer(self.device, self.pin_memory)
|
||||
self._kv_block_zeroer.init_meta(
|
||||
self._kv_block_zeroer = KVBlockZeroer(
|
||||
self.device,
|
||||
self.pin_memory,
|
||||
attn_groups_iter=self._kv_cache_spec_attn_group_iterator(),
|
||||
kernel_block_sizes=self._kernel_block_sizes,
|
||||
cache_dtype=self.cache_config.cache_dtype,
|
||||
|
||||
+16
-14
@@ -80,30 +80,23 @@ def _zero_kv_blocks_kernel(
|
||||
class KVBlockZeroer:
|
||||
"""Manages efficient zeroing of KV cache blocks via a Triton kernel.
|
||||
|
||||
Call :meth:`init_meta` once after KV caches are allocated to precompute
|
||||
segment addresses, then call :meth:`zero_block_ids` each step to zero
|
||||
Construct once after KV caches are allocated to precompute segment
|
||||
addresses, then call :meth:`zero_block_ids` each step to zero
|
||||
newly-allocated blocks.
|
||||
"""
|
||||
|
||||
def __init__(self, device: torch.device, pin_memory: bool):
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
self._meta: tuple[torch.Tensor, int, int, int] | None = None
|
||||
self._id_cap: int = 0
|
||||
self._ids_pinned: torch.Tensor | None = None
|
||||
self._ids_gpu: torch.Tensor | None = None
|
||||
|
||||
def init_meta(
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
attn_groups_iter: Iterable["AttentionGroup"],
|
||||
kernel_block_sizes: list[int],
|
||||
cache_dtype: str,
|
||||
runner_only_attn_layers: set[str],
|
||||
static_forward_context: dict[str, Any],
|
||||
runner_only_attn_layers: set[str] | None = None,
|
||||
) -> None:
|
||||
"""One-time precomputation for zero_block_ids.
|
||||
"""Precompute the absolute-address table for the Triton zeroing kernel.
|
||||
|
||||
Builds absolute-address table for the Triton zeroing kernel.
|
||||
Each entry is the absolute byte address of a segment start on the
|
||||
GPU, so segments in different CUDA allocations work correctly.
|
||||
|
||||
@@ -114,6 +107,15 @@ class KVBlockZeroer:
|
||||
|
||||
Only AttentionSpec layers are processed; Mamba layers are skipped.
|
||||
"""
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
self._meta: tuple[torch.Tensor, int, int, int] | None = None
|
||||
self._id_cap: int = 0
|
||||
self._ids_pinned: torch.Tensor | None = None
|
||||
self._ids_gpu: torch.Tensor | None = None
|
||||
|
||||
if runner_only_attn_layers is None:
|
||||
runner_only_attn_layers = set()
|
||||
seen_ptrs: set[int] = set()
|
||||
seg_addrs: list[int] = []
|
||||
page_size_el: int | None = None
|
||||
|
||||
Reference in New Issue
Block a user