[Model Runner V2] Support zeroing freshly allocated KV blocks for hybrid + fp8 KVCache (#43990)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
zhrrr
2026-06-02 13:56:53 +08:00
committed by GitHub
parent f91fb2fcf3
commit 8a9eb40808
3 changed files with 45 additions and 20 deletions
+26 -3
View File
@@ -47,6 +47,7 @@ from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
@@ -103,6 +104,7 @@ from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.utils import KVBlockZeroer
logger = init_logger(__name__)
@@ -129,6 +131,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.cache_config.cache_dtype
]
# Lazily initialized in _init_kv_zero_meta() when the KV cache needs
# zeroing (e.g. hybrid models with fp8 KV cache).
self.kv_block_zeroer: KVBlockZeroer | None = None
self.vocab_size = self.model_config.get_vocab_size()
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
@@ -393,7 +399,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) + spec.num_speculative_blocks
max_num_blocks_per_group.append(max_num_blocks)
self.attn_groups, attn_cg_support, kernel_block_sizes = init_attn_backend(
self.attn_groups, attn_cg_support, self.kernel_block_sizes = init_attn_backend(
self.kv_cache_config, self.vllm_config, self.device
)
self.block_tables = BlockTables(
@@ -402,7 +408,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_batched_tokens=self.max_num_tokens,
max_num_blocks_per_group=max_num_blocks_per_group,
device=self.device,
kernel_block_sizes=kernel_block_sizes,
kernel_block_sizes=self.kernel_block_sizes,
cp_size=self.dcp_size,
cp_rank=self.dcp_rank,
cp_interleave=self.cp_interleave,
@@ -442,11 +448,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.attn_groups,
self.device,
self.cache_config.cache_dtype,
kernel_block_sizes,
self.kernel_block_sizes,
self.vllm_config,
)
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
def _init_kv_zero_meta(self) -> None:
"""Build KV-block zeroing metadata; invoked from gpu_worker.

Constructs a KVBlockZeroer from this runner's attention groups, kernel
block sizes, cache dtype and static forward context, and stores it on
``self.kv_block_zeroer``. That attribute starts out as ``None`` and is
only populated here, so the per-step zeroing path asserts it is set
before calling ``zero_block_ids``.
"""
self.kv_block_zeroer = KVBlockZeroer(
self.device,
is_pin_memory_available(),
# self.attn_groups is nested; flatten it into a single stream of groups.
attn_groups_iter=(g for groups in self.attn_groups for g in groups),
kernel_block_sizes=self.kernel_block_sizes,
cache_dtype=self.cache_config.cache_dtype,
static_forward_context=self.compilation_config.static_forward_context,
)
@torch.inference_mode()
@step_eplb_after(is_dummy=True)
def _dummy_run(
@@ -753,6 +770,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
out=self.req_states.num_computed_prefill_tokens,
)
# Zero GPU memory for freshly allocated cache blocks to prevent
# stale NaN/data from corrupting attention or SSM computation.
if scheduler_output.new_block_ids_to_zero:
assert self.kv_block_zeroer is not None
self.kv_block_zeroer.zero_block_ids(scheduler_output.new_block_ids_to_zero)
def prepare_inputs(
self, scheduler_output: SchedulerOutput, batch_desc: BatchExecutionDescriptor
) -> InputBatch:
+3 -3
View File
@@ -1084,11 +1084,11 @@ class GPUModelRunner(
def _init_kv_zero_meta(self) -> None:
"""One-time precomputation for _zero_block_ids.
Constructs a KVBlockZeroer from the runner's state.
Called from gpu_worker.py outside the CuMem pool context.
"""
self._kv_block_zeroer = KVBlockZeroer(self.device, self.pin_memory)
self._kv_block_zeroer.init_meta(
self._kv_block_zeroer = KVBlockZeroer(
self.device,
self.pin_memory,
attn_groups_iter=self._kv_cache_spec_attn_group_iterator(),
kernel_block_sizes=self._kernel_block_sizes,
cache_dtype=self.cache_config.cache_dtype,
+16 -14
View File
@@ -80,30 +80,23 @@ def _zero_kv_blocks_kernel(
class KVBlockZeroer:
"""Manages efficient zeroing of KV cache blocks via a Triton kernel.
Call :meth:`init_meta` once after KV caches are allocated to precompute
segment addresses, then call :meth:`zero_block_ids` each step to zero
Construct once after KV caches are allocated to precompute segment
addresses, then call :meth:`zero_block_ids` each step to zero
newly-allocated blocks.
"""
def __init__(self, device: torch.device, pin_memory: bool):
self.device = device
self.pin_memory = pin_memory
self._meta: tuple[torch.Tensor, int, int, int] | None = None
self._id_cap: int = 0
self._ids_pinned: torch.Tensor | None = None
self._ids_gpu: torch.Tensor | None = None
def init_meta(
def __init__(
self,
device: torch.device,
pin_memory: bool,
attn_groups_iter: Iterable["AttentionGroup"],
kernel_block_sizes: list[int],
cache_dtype: str,
runner_only_attn_layers: set[str],
static_forward_context: dict[str, Any],
runner_only_attn_layers: set[str] | None = None,
) -> None:
"""One-time precomputation for zero_block_ids.
"""Precompute the absolute-address table for the Triton zeroing kernel.
Builds absolute-address table for the Triton zeroing kernel.
Each entry is the absolute byte address of a segment start on the
GPU, so segments in different CUDA allocations work correctly.
@@ -114,6 +107,15 @@ class KVBlockZeroer:
Only AttentionSpec layers are processed; Mamba layers are skipped.
"""
self.device = device
self.pin_memory = pin_memory
self._meta: tuple[torch.Tensor, int, int, int] | None = None
self._id_cap: int = 0
self._ids_pinned: torch.Tensor | None = None
self._ids_gpu: torch.Tensor | None = None
if runner_only_attn_layers is None:
runner_only_attn_layers = set()
seen_ptrs: set[int] = set()
seg_addrs: list[int] = []
page_size_el: int | None = None