diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 3eae859f80..a1890da391 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -571,7 +571,40 @@ class PyTorchModelEngine(ModelEngine):
         finally:
             self.cuda_graph_runner.enabled = _run_cuda_graphs
 
+    @staticmethod
+    def warmup_with_kv_cache_cleanup(method):
+        """
+        Decorator for warmup methods that cleans up NaNs/Infs in KV Cache after warmup execution.
+
+        Why this is needed:
+        - Our attention kernel uses multiplication by zero to mask out invalid tokens within
+          the same page. Since NaN/Inf * 0 = NaN, any NaNs/Infs in these invalid KV areas
+          will persist after masking.
+        - These NaNs/Infs propagate to outputs and subsequent KV Cache entries, corrupting
+          future computations with higher probability.
+        - During warmup, we execute with placeholder data rather than actual valid inputs,
+          which can introduce NaNs/Infs into KV Cache pages and cause random, hard-to-debug
+          accuracy issues.
+        """
+
+        @functools.wraps(method)
+        def wrapper(self, resource_manager: ResourceManager, *args, **kwargs):
+            result = method(self, resource_manager, *args, **kwargs)
+            kv_cache_manager = resource_manager.get_resource_manager(
+                self.kv_cache_manager_key)
+            if kv_cache_manager is not None:
+                has_invalid_values = kv_cache_manager.check_invalid_values_in_kv_cache(
+                    fill_with_zero=True)
+                if has_invalid_values:
+                    logger.warning(
+                        "NaNs/Infs have been introduced to KVCache during warmup, KVCache was filled with zeros to avoid potential issues"
+                    )
+            return result
+
+        return wrapper
+
     @with_warmup_flag
+    @warmup_with_kv_cache_cleanup
     def warmup(self, resource_manager: ResourceManager) -> None:
         """
         Orchestrates the warmup process by calling specialized warmup methods for
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index 82bdce620e..0d739c807c 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -941,6 +941,34 @@ class KVCacheManager(BaseResourceManager):
         result = self.impl.get_indexer_k_cache_pool_data(layer_idx)
         return result.view(result.shape[0], -1)
 
+    def check_invalid_values_in_kv_cache(self,
+                                         fill_with_zero: bool = False) -> bool:
+        some_checks_unavailable = False
+        has_invalid_values = torch.tensor([False],
+                                          dtype=torch.bool,
+                                          device=torch.cuda.current_device())
+        for layer_idx, layer_offset in self.layer_offsets.items():
+            buffer = self.impl.get_primary_pool_data(layer_offset)
+            # process in chunks of 256 pages to avoid OoM
+            for i in range(0, buffer.shape[0], 256):
+                buffer_slice = buffer[i:i + 256]
+                try:
+                    has_invalid_values.logical_or_(
+                        torch.isnan(buffer_slice).any())
+                    has_invalid_values.logical_or_(
+                        torch.isinf(buffer_slice).any())
+                except NotImplementedError:
+                    some_checks_unavailable = True
+            if fill_with_zero:
+                buffer.zero_()
+        torch.cuda.synchronize()
+
+        if some_checks_unavailable:
+            logger.warning(
+                "`torch.isnan` or `torch.isinf` is not implemented for current kv cache dtype, related checks are skipped"
+            )
+        return bool(has_invalid_values)
+
     def get_unique_primary_pool(self) -> torch.Tensor:
         return self.impl.get_unique_primary_pool()
 