[None][feat] Add KV cache cleanup (#7439)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
Pengbo Wang 2026-01-22 15:14:17 +08:00 committed by GitHub
parent fd2af8d58a
commit 9462d90ec7
2 changed files with 61 additions and 0 deletions


@@ -571,7 +571,40 @@ class PyTorchModelEngine(ModelEngine):
        finally:
            self.cuda_graph_runner.enabled = _run_cuda_graphs

    @staticmethod
    def warmup_with_kv_cache_cleanup(method):
        """
        Decorator for warmup methods that cleans up NaNs/Infs in the KV cache
        after warmup execution.

        Why this is needed:
        - Our attention kernel uses multiplication by zero to mask out invalid
          tokens within the same page. Since NaN/Inf * 0 = NaN, any NaNs/Infs
          in these invalid KV areas persist after masking.
        - These NaNs/Infs then propagate to outputs and to subsequent KV cache
          entries, making corruption of future computations increasingly likely.
        - During warmup, we execute with placeholder data rather than actual
          valid inputs, which can introduce NaNs/Infs into KV cache pages and
          cause random, hard-to-debug accuracy issues.
        """

        @functools.wraps(method)
        def wrapper(self, resource_manager: ResourceManager, *args, **kwargs):
            result = method(self, resource_manager, *args, **kwargs)
            kv_cache_manager = resource_manager.get_resource_manager(
                self.kv_cache_manager_key)
            if kv_cache_manager is not None:
                has_invalid_values = kv_cache_manager.check_invalid_values_in_kv_cache(
                    fill_with_zero=True)
                if has_invalid_values:
                    logger.warning(
                        "NaNs/Infs were introduced into the KV cache during warmup; "
                        "the KV cache has been filled with zeros to avoid potential issues"
                    )
            return result

        return wrapper
    @with_warmup_flag
    @warmup_with_kv_cache_cleanup
    def warmup(self, resource_manager: ResourceManager) -> None:
        """
        Orchestrates the warmup process by calling specialized warmup methods for

@@ -941,6 +941,34 @@ class KVCacheManager(BaseResourceManager):
        result = self.impl.get_indexer_k_cache_pool_data(layer_idx)
        return result.view(result.shape[0], -1)

    def check_invalid_values_in_kv_cache(self,
                                         fill_with_zero: bool = False) -> bool:
        some_checks_unavailable = False
        has_invalid_values = torch.tensor([False],
                                          dtype=torch.bool,
                                          device=torch.cuda.current_device())
        for layer_idx, layer_offset in self.layer_offsets.items():
            buffer = self.impl.get_primary_pool_data(layer_offset)
            # Process in chunks of 256 pages to avoid OOM.
            for i in range(0, buffer.shape[0], 256):
                buffer_slice = buffer[i:i + 256]
                try:
                    has_invalid_values.logical_or_(
                        torch.isnan(buffer_slice).any())
                    has_invalid_values.logical_or_(
                        torch.isinf(buffer_slice).any())
                except NotImplementedError:
                    some_checks_unavailable = True
            if fill_with_zero:
                buffer.zero_()
        torch.cuda.synchronize()
        if some_checks_unavailable:
            logger.warning(
                "`torch.isnan` or `torch.isinf` is not implemented for the "
                "current KV cache dtype; the related checks were skipped")
        return bool(has_invalid_values)
    def get_unique_primary_pool(self) -> torch.Tensor:
        return self.impl.get_unique_primary_pool()
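The chunked scan above bounds peak memory: torch.isnan(buffer) materializes a boolean tensor with the buffer's full shape, so reducing 256-page slices with .any() keeps that temporary small. A minimal standalone sketch of the same pattern (scan_for_invalid is a hypothetical helper, not part of this diff):

    import torch

    def scan_for_invalid(buffer: torch.Tensor, chunk_pages: int = 256) -> bool:
        # Reduce each slice to a scalar flag so only a chunk-sized mask is
        # ever allocated, mirroring check_invalid_values_in_kv_cache above.
        found = torch.zeros(1, dtype=torch.bool, device=buffer.device)
        for i in range(0, buffer.shape[0], chunk_pages):
            chunk = buffer[i:i + chunk_pages]
            found.logical_or_(torch.isnan(chunk).any())
            found.logical_or_(torch.isinf(chunk).any())
        return bool(found)

    print(scan_for_invalid(torch.full((1024, 8), float("inf"))))  # True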