mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-05 18:51:38 +08:00
[None][feat] Add KV cache cleanup (#7439)
Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
This commit is contained in:
parent
fd2af8d58a
commit
9462d90ec7
@ -571,7 +571,40 @@ class PyTorchModelEngine(ModelEngine):
|
||||
finally:
|
||||
self.cuda_graph_runner.enabled = _run_cuda_graphs
|
||||
|
||||
@staticmethod
def warmup_with_kv_cache_cleanup(method):
    """
    Decorator for warmup methods: after the wrapped warmup runs, scrub any
    NaNs/Infs the warmup left behind in the KV cache.

    Why this is needed:
    - Our attention kernel masks invalid tokens within a page by multiplying
      by zero. Since NaN/Inf * 0 = NaN, NaNs/Infs sitting in those invalid
      KV areas survive the masking.
    - Surviving NaNs/Infs leak into outputs and into later KV cache entries,
      so they are increasingly likely to corrupt future computations.
    - Warmup executes with placeholder data rather than real inputs, which
      can seed KV cache pages with NaNs/Infs and cause random,
      hard-to-debug accuracy issues.
    """

    @functools.wraps(method)
    def wrapper(self, resource_manager: ResourceManager, *args, **kwargs):
        # Run the actual warmup first; cleanup happens regardless of result.
        outcome = method(self, resource_manager, *args, **kwargs)
        manager = resource_manager.get_resource_manager(
            self.kv_cache_manager_key)
        # No KV cache manager registered -> nothing to scrub.
        if manager is not None and manager.check_invalid_values_in_kv_cache(
                fill_with_zero=True):
            logger.warning(
                "NaNs/Infs have been introduced to KVCache during warmup, KVCache was filled with zeros to avoid potential issues"
            )
        return outcome

    return wrapper
|
||||
|
||||
@with_warmup_flag
|
||||
@warmup_with_kv_cache_cleanup
|
||||
def warmup(self, resource_manager: ResourceManager) -> None:
|
||||
"""
|
||||
Orchestrates the warmup process by calling specialized warmup methods for
|
||||
|
||||
@ -941,6 +941,34 @@ class KVCacheManager(BaseResourceManager):
|
||||
result = self.impl.get_indexer_k_cache_pool_data(layer_idx)
|
||||
return result.view(result.shape[0], -1)
|
||||
|
||||
def check_invalid_values_in_kv_cache(self,
                                     fill_with_zero: bool = False,
                                     chunk_pages: int = 256) -> bool:
    """Scan every layer's primary KV cache pool for NaN/Inf values.

    Args:
        fill_with_zero: If True, zero out each layer's pool after it is
            scanned, so any NaN/Inf contamination cannot propagate further.
        chunk_pages: Number of pages scanned per step; bounds the size of
            the temporaries allocated by the isnan/isinf reductions
            (default 256, matching the previous hard-coded chunk size).

    Returns:
        True if any NaN or Inf was found. Buffers whose dtype does not
        implement `torch.isnan`/`torch.isinf` are skipped (a warning is
        logged) and do not contribute to the result.
    """
    some_checks_unavailable = False
    has_invalid_values = torch.tensor([False],
                                      dtype=torch.bool,
                                      device=torch.cuda.current_device())
    # Only the pool offsets are needed here; the layer indices are unused.
    for layer_offset in self.layer_offsets.values():
        buffer = self.impl.get_primary_pool_data(layer_offset)
        # Process in chunks of `chunk_pages` pages to avoid OoM from the
        # temporaries that isnan/isinf would allocate over the full pool.
        for i in range(0, buffer.shape[0], chunk_pages):
            buffer_slice = buffer[i:i + chunk_pages]
            try:
                has_invalid_values.logical_or_(
                    torch.isnan(buffer_slice).any())
                has_invalid_values.logical_or_(
                    torch.isinf(buffer_slice).any())
            except NotImplementedError:
                # isnan/isinf unsupported for this dtype; remember so the
                # caller is warned that the check was incomplete.
                some_checks_unavailable = True
        if fill_with_zero:
            buffer.zero_()
    # Ensure all device-side scans/zeroing finish before reading the flag.
    torch.cuda.synchronize()

    if some_checks_unavailable:
        logger.warning(
            "`torch.isnan` or `torch.isinf` is not implemented for current kv cache dtype, related checks are skipped"
        )
    return bool(has_invalid_values)
|
||||
|
||||
def get_unique_primary_pool(self) -> torch.Tensor:
    """Return the unique primary KV cache pool tensor held by the impl."""
    pool = self.impl.get_unique_primary_pool()
    return pool
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user