diff --git a/docs/source/features/auto_deploy/advanced/benchmarking_with_trtllm_bench.md b/docs/source/features/auto_deploy/advanced/benchmarking_with_trtllm_bench.md index 84f8015889..6c738d445d 100644 --- a/docs/source/features/auto_deploy/advanced/benchmarking_with_trtllm_bench.md +++ b/docs/source/features/auto_deploy/advanced/benchmarking_with_trtllm_bench.md @@ -55,14 +55,17 @@ skip_loading_weights: false # Sequence configuration max_batch_size: 256 +# transform options +# KV cache configuration +kv_cache_config: + # fraction of free memory to use for kv-caches + free_gpu_memory_fraction: 0.8 + # transform options transforms: insert_cached_attention: # attention backend backend: flashinfer - resize_kv_cache: - # fraction of free memory to use for kv-caches - free_mem_ratio: 0.8 compile_model: # compilation backend backend: torch-opt @@ -80,7 +83,7 @@ Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GP |-----------|---------|-------------| | `compile_backend` | `torch-compile` | Compilation backend: `torch-simple`, `torch-compile`, `torch-cudagraph`, `torch-opt` | | `runtime` | `trtllm` | Runtime engine: `trtllm`, `demollm` | -| `free_mem_ratio` | `0.0` | Fraction of available GPU memory for KV cache (0.0-1.0) | +| `kv_cache_config.free_gpu_memory_fraction` | `0.9` | Fraction of available GPU memory for KV cache (0.0-1.0) | | `skip_loading_weights` | `false` | Skip weight loading for architecture-only benchmarks | ### CUDA Graph Optimization @@ -95,7 +98,7 @@ For optimal CUDA graph performance, specify batch sizes that match your expected ## Performance Optimization Tips -1. **Memory Management**: Set `free_mem_ratio` to 0.8-0.9 for optimal KV cache utilization +1. **Memory Management**: Set `kv_cache_config.free_gpu_memory_fraction` to 0.8-0.9 for optimal KV cache utilization 1. **Compilation Backend**: Use `torch-opt` for production workloads 1. **Attention Backend**: `flashinfer` generally provides the best performance for most models 1. **CUDA Graphs**: Enable CUDA graphs for batch sizes that match your production traffic patterns. 
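For readers migrating existing benchmark configs, the change above amounts to moving the memory fraction out of the `resize_kv_cache` transform and into the top-level `kv_cache_config` section. A before/after sketch of just that portion of the YAML (values taken from the example above):

```yaml
# before: KV-cache sizing lived under the resize_kv_cache transform
transforms:
  resize_kv_cache:
    free_mem_ratio: 0.8

# after: KV-cache sizing is a top-level kv_cache_config section
kv_cache_config:
  free_gpu_memory_fraction: 0.8
```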
diff --git a/docs/source/features/auto_deploy/advanced/workflow.md b/docs/source/features/auto_deploy/advanced/workflow.md index f1bd715029..c828323b92 100644 --- a/docs/source/features/auto_deploy/advanced/workflow.md +++ b/docs/source/features/auto_deploy/advanced/workflow.md @@ -15,10 +15,8 @@ llm = LLM( compile_backend="torch-compile", model_kwargs={"num_hidden_layers": 2}, # test with smaller model configuration attn_backend="flashinfer", # choose between "triton" and "flashinfer" - attn_page_size=64, # page size for attention (tokens_per_block, should be == max_seq_len for triton) skip_loading_weights=False, model_factory="AutoModelForCausalLM", # choose appropriate model factory - free_mem_ratio=0.8, # fraction of available memory for cache max_seq_len=, max_batch_size=, ) diff --git a/docs/source/torch/auto_deploy/advanced/benchmarking_with_trtllm_bench.md b/docs/source/torch/auto_deploy/advanced/benchmarking_with_trtllm_bench.md index 2f37c716cf..e55b5c3264 100644 --- a/docs/source/torch/auto_deploy/advanced/benchmarking_with_trtllm_bench.md +++ b/docs/source/torch/auto_deploy/advanced/benchmarking_with_trtllm_bench.md @@ -49,14 +49,17 @@ skip_loading_weights: false # Sequence configuration max_batch_size: 256 +# transform options +# KV cache configuration +kv_cache_config: + # fraction of free memory to use for kv-caches + free_gpu_memory_fraction: 0.9 + # transform options transforms: insert_cached_attention: # attention backend backend: flashinfer - resize_kv_cache: - # fraction of free memory to use for kv-caches - free_mem_ratio: 0.8 compile_model: # compilation backend backend: torch-opt @@ -74,7 +77,7 @@ Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GP |-----------|---------|-------------| | `compile_backend` | `torch-compile` | Compilation backend: `torch-simple`, `torch-compile`, `torch-cudagraph`, `torch-opt` | | `runtime` | `trtllm` | Runtime engine: `trtllm`, `demollm` | -| `free_mem_ratio` | `0.0` | Fraction of available GPU memory for KV cache (0.0-1.0) | +| `kv_cache_config.free_gpu_memory_fraction` | `0.9` | Fraction of available GPU memory for KV cache (0.0-1.0) | | `skip_loading_weights` | `false` | Skip weight loading for architecture-only benchmarks | ### CUDA Graph Optimization @@ -89,7 +92,7 @@ For optimal CUDA graph performance, specify batch sizes that match your expected ## Performance Optimization Tips -1. **Memory Management**: Set `free_mem_ratio` to 0.8-0.9 for optimal KV cache utilization +1. **Memory Management**: Set `kv_cache_config.free_gpu_memory_fraction` to 0.8-0.9 for optimal KV cache utilization 1. **Compilation Backend**: Use `torch-opt` for production workloads 1. **Attention Backend**: `flashinfer` generally provides the best performance for most models 1. **CUDA Graphs**: Enable CUDA graphs for batch sizes that match your production traffic patterns. 
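The `workflow.md` snippet above drops `attn_page_size` and `free_mem_ratio` from the `LLM(...)` constructor. A minimal sketch of the updated call, assuming the AutoDeploy `LLM` accepts the standard `kv_cache_config` argument for memory tuning; the model id and the sequence/batch limits are illustrative placeholders:

```python
from tensorrt_llm._torch.auto_deploy import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    compile_backend="torch-compile",
    attn_backend="flashinfer",  # choose between "triton" and "flashinfer"
    model_factory="AutoModelForCausalLM",
    skip_loading_weights=False,
    kv_cache_config={"free_gpu_memory_fraction": 0.8},  # replaces the removed free_mem_ratio
    max_seq_len=4096,
    max_batch_size=8,
)
```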
diff --git a/docs/source/torch/auto_deploy/advanced/serving_with_trtllm_serve.md b/docs/source/torch/auto_deploy/advanced/serving_with_trtllm_serve.md index 20693f6170..18c2de281a 100644 --- a/docs/source/torch/auto_deploy/advanced/serving_with_trtllm_serve.md +++ b/docs/source/torch/auto_deploy/advanced/serving_with_trtllm_serve.md @@ -54,14 +54,17 @@ max_batch_size: 256 # multi-gpu execution world_size: 1 +# transform options +# KV cache configuration +kv_cache_config: + # fraction of free memory to use for kv-caches + free_gpu_memory_fraction: 0.9 + # transform options transforms: insert_cached_attention: # attention backend backend: flashinfer - resize_kv_cache: - # fraction of free memory to use for kv-caches - free_mem_ratio: 0.8 compile_model: # compilation backend backend: torch-opt @@ -77,7 +80,7 @@ transforms: - Prefer `compile_backend: torch-opt` - Use `attn_backend: flashinfer` - Set realistic `cuda_graph_batch_sizes` that match expected traffic - - Tune `free_mem_ratio` to 0.8–0.9 + - Tune `kv_cache_config.free_gpu_memory_fraction` to 0.8–0.9 ## See also diff --git a/docs/source/torch/auto_deploy/advanced/workflow.md b/docs/source/torch/auto_deploy/advanced/workflow.md index 5debad44d3..99cc5c0f90 100644 --- a/docs/source/torch/auto_deploy/advanced/workflow.md +++ b/docs/source/torch/auto_deploy/advanced/workflow.md @@ -17,12 +17,9 @@ llm = LLM( transforms={ "insert_cached_attention": {"backend": "flashinfer"}, # or "triton" "insert_cached_mla_attention": {"backend": "MultiHeadLatentAttention"}, - "resize_kv_cache": {"free_mem_ratio": 0.8}, "compile_model": {"backend": "torch-compile"}, "detect_sharding": {"simple_shard_only": False}, - }, - attn_page_size=64, # page size for attention skip_loading_weights=False, max_seq_len=, max_batch_size=, diff --git a/examples/auto_deploy/README.md b/examples/auto_deploy/README.md index 1a90c64948..ccad5f6231 100644 --- a/examples/auto_deploy/README.md +++ b/examples/auto_deploy/README.md @@ -49,7 +49,6 @@ Below is a non-exhaustive list of common configuration options: | `--args.mla-backend` | Specifies implementation for multi-head latent attention | | `--args.max-seq-len` | Maximum sequence length for inference/cache | | `--args.max-batch-size` | Maximum dimension for statically allocated KV cache | -| `--args.attn-page-size` | Page size for attention | | `--prompt.batch-size` | Number of queries to generate | | `--benchmark.enabled` | Whether to run the built-in benchmark (true/false) | @@ -125,10 +124,8 @@ llm = LLM( compile_backend="torch-compile", model_kwargs={"num_hidden_layers": 2}, # test with smaller model configuration attn_backend="flashinfer", # choose between "triton" and "flashinfer" - attn_page_size=64, # page size for attention (tokens_per_block, should be == max_seq_len for triton) skip_loading_weights=False, model_factory="AutoModelForCausalLM", # choose appropriate model factory - free_mem_ratio=0.8, # fraction of available memory for cache max_seq_len=, max_batch_size=, ) diff --git a/examples/auto_deploy/model_registry/configs/llama3_3_70b.yaml b/examples/auto_deploy/model_registry/configs/llama3_3_70b.yaml index 828800c93b..427faaf70f 100644 --- a/examples/auto_deploy/model_registry/configs/llama3_3_70b.yaml +++ b/examples/auto_deploy/model_registry/configs/llama3_3_70b.yaml @@ -3,7 +3,6 @@ max_batch_size: 1024 max_num_tokens: 2048 -free_mem_ratio: 0.9 trust_remote_code: true cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024] kv_cache_config: diff --git 
a/examples/auto_deploy/model_registry/configs/llama4_scout.yaml b/examples/auto_deploy/model_registry/configs/llama4_scout.yaml index 25b5c98971..2a5bb8fc15 100644 --- a/examples/auto_deploy/model_registry/configs/llama4_scout.yaml +++ b/examples/auto_deploy/model_registry/configs/llama4_scout.yaml @@ -3,7 +3,6 @@ max_batch_size: 1024 max_num_tokens: 2048 -free_mem_ratio: 0.9 trust_remote_code: true cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024] kv_cache_config: diff --git a/examples/auto_deploy/nano_v3.yaml b/examples/auto_deploy/nano_v3.yaml index 7e30caf1c5..8844bc228e 100644 --- a/examples/auto_deploy/nano_v3.yaml +++ b/examples/auto_deploy/nano_v3.yaml @@ -6,12 +6,12 @@ enable_chunked_prefill: true attn_backend: flashinfer model_factory: AutoModelForCausalLM skip_loading_weights: false -# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9884 -free_mem_ratio: 0.88 cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384] kv_cache_config: - # disable kv_cache reuse since not supported for hybrid/ssm models - enable_block_reuse: false + free_gpu_memory_fraction: 0.88 + # tunable mamba cache dtype + # --> use float32 for accuracy and default (auto) for speed + mamba_ssm_cache_dtype: auto transforms: detect_sharding: allreduce_strategy: SYMM_MEM @@ -39,12 +39,6 @@ transforms: multi_stream_moe: stage: compile enabled: true - # tunable mamba cache dtype - # --> use float32 for accuracy and default (null) for speed - insert_cached_ssm_attention: - cache_config: - # mamba_dtype: float32 - mamba_dtype: null gather_logits_before_lm_head: # TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default enabled: true diff --git a/examples/auto_deploy/nemotron_flash.yaml b/examples/auto_deploy/nemotron_flash.yaml index 9872c23ed4..6e1d782426 100644 --- a/examples/auto_deploy/nemotron_flash.yaml +++ b/examples/auto_deploy/nemotron_flash.yaml @@ -4,11 +4,7 @@ max_seq_len: 2097152 max_num_tokens: 8192 enable_chunked_prefill: true model_factory: NemotronFlashForCausalLM -free_mem_ratio: 0.9 cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 96, 128, 256, 320, 384] -kv_cache_config: - # disable kv_cache reuse since not supported for hybrid/ssm models - enable_block_reuse: false transforms: gather_logits_before_lm_head: # TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default diff --git a/examples/auto_deploy/super_v3.yaml b/examples/auto_deploy/super_v3.yaml index 707f3ef2de..56e7f292da 100644 --- a/examples/auto_deploy/super_v3.yaml +++ b/examples/auto_deploy/super_v3.yaml @@ -6,11 +6,11 @@ enable_chunked_prefill: true attn_backend: flashinfer model_factory: AutoModelForCausalLM skip_loading_weights: false -free_mem_ratio: 0.9 cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384] kv_cache_config: - # disable kv_cache reuse since not supported for hybrid/ssm models - enable_block_reuse: false + # tunable mamba cache dtype + # --> use float32 for accuracy and default (auto) for speed + mamba_ssm_cache_dtype: auto transforms: detect_sharding: allreduce_strategy: SYMM_MEM @@ -38,12 +38,6 @@ transforms: multi_stream_moe: stage: compile enabled: false - # tunable mamba cache dtype - # --> use float32 for accuracy and default (null) for speed - insert_cached_ssm_attention: - cache_config: - # mamba_dtype: float32 - mamba_dtype: null gather_logits_before_lm_head: # TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default enabled: true diff --git 
a/examples/models/core/nemotron/README_nemotron_nano_v3.md b/examples/models/core/nemotron/README_nemotron_nano_v3.md index d60a76ec5e..8bef20fe92 100644 --- a/examples/models/core/nemotron/README_nemotron_nano_v3.md +++ b/examples/models/core/nemotron/README_nemotron_nano_v3.md @@ -37,15 +37,16 @@ enable_chunked_prefill: true attn_backend: flashinfer model_factory: AutoModelForCausalLM skip_loading_weights: false -free_mem_ratio: 0.9 cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384] kv_cache_config: - # disable kv_cache reuse since not supported for hybrid/ssm models - enable_block_reuse: false + free_gpu_memory_fraction: 0.88 + # tunable mamba cache dtype + # --> use float32 for accuracy and default (auto) for speed + # mamba_ssm_cache_dtype: float32 transforms: detect_sharding: - sharding_dims: ['ep', 'bmm'] allreduce_strategy: 'SYMM_MEM' + sharding_dims: ['ep', 'bmm'] manual_config: head_dim: 128 tp_plan: @@ -69,9 +70,8 @@ transforms: multi_stream_moe: stage: compile enabled: true - insert_cached_ssm_attention: - cache_config: - mamba_dtype: float32 + gather_logits_before_lm_head: + enabled: true fuse_mamba_a_log: stage: post_load_fusion enabled: true diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 12ea7c5edd..5fc629fd8f 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -99,9 +99,11 @@ transforms: stage: weight_load run_per_gm: false checkpoint_device: null + expect_mem_change: true move_inputs_to_device: stage: weight_load run_per_gm: false + expect_mem_change: true ############################################################################################ # RUN POST-LOAD FUSION AND OPTIMIZATIONS ############################################################################################ @@ -122,18 +124,17 @@ transforms: backend: trtllm fuse_moe: stage: post_load_fusion - enabled: true + expect_mem_change: true backend: trtllm fuse_fp8_moe: stage: post_load_fusion - enabled: true + expect_mem_change: true backend: trtllm fuse_nvfp4_moe: stage: post_load_fusion - enabled: true + expect_mem_change: true fuse_allreduce_residual_rmsnorm: stage: post_load_fusion - # TODO (lucaslie): add backend selection as part of configurable inference optimizers fuse_rmsnorm: stage: post_load_fusion rmsnorm_backend: flashinfer @@ -175,11 +176,12 @@ transforms: backend: cached_residual_add initialize_cache: stage: cache_init + expect_mem_change: true run_per_gm: false resize_kv_cache: stage: cache_init + expect_mem_change: true run_per_gm: false - free_mem_ratio: 0.0 ############################################################################################ # COMPILE MODEL ############################################################################################ @@ -190,6 +192,7 @@ transforms: enabled: false compile_model: stage: compile + expect_mem_change: true run_per_gm: false cuda_graph_batch_sizes: null backend: torch-compile diff --git a/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml b/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml index 42188ab94b..07d40c072f 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml @@ -7,12 +7,14 @@ transforms: build_and_load_factory_model: stage: factory run_per_gm: false + expect_mem_change: true ############################################################################################ # 
MOVE ARGUMENTS TO DEVICE ############################################################################################ move_inputs_to_device: stage: weight_load run_per_gm: false + expect_mem_change: true ############################################################################################ # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES ############################################################################################ @@ -26,10 +28,11 @@ transforms: initialize_cache: stage: cache_init run_per_gm: false + expect_mem_change: true resize_kv_cache: stage: cache_init run_per_gm: false - free_mem_ratio: 0.0 + expect_mem_change: true ############################################################################################ # COMPILE MODEL ############################################################################################ diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 3131d87cf8..170463dae9 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -9,16 +9,18 @@ and operates on a purely functional paradigm that is compatible with the torch c """ +import math from abc import ABC, abstractmethod -from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union +from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union, final import torch -from pydantic import BaseModel, ConfigDict, Field, field_validator from torch._ops import OpOverloadPacket from torch.fx import Node from torch.types import Number -from ...._utils import nvtx_range +from tensorrt_llm.llmapi.llm_args import KvCacheConfig + +from ...._utils import nvtx_range, str_dtype_to_torch from ..utils.logger import ad_logger Constant = Union[int, float, str, None] @@ -281,44 +283,6 @@ class InputBuffer: self._device_views = self._create_views(self._device_buffer) -class CacheConfig(BaseModel): - """Cache configuration for attention-related dtypes.""" - - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="forbid", - ) - - dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.") - mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.") - delta_dtype: Optional[torch.dtype] = Field( - default=torch.float32, description="Delta cache dtype. Defaults to float32." - ) - - @field_validator("dtype", "mamba_dtype", "delta_dtype", mode="before") - @classmethod - def _coerce_dtype(cls, value): - if value is None or isinstance(value, torch.dtype): - return value - if isinstance(value, str): - dtype = getattr(torch, value, None) - assert isinstance(dtype, torch.dtype), f"Invalid {dtype=}" - return dtype - return value - - def __or__(self, other: "CacheConfig") -> "CacheConfig": - """Combine two CacheConfig objects field-wise using Python's `or` semantics. - - For each field, selects the first non-None value between `self` and `other`. - """ - if not isinstance(other, CacheConfig): - raise NotImplementedError(f"Cannot combine CacheConfig with {type(other)}") - merged_kwargs = {} - for field_name in type(self).model_fields.keys(): - merged_kwargs[field_name] = getattr(self, field_name) or getattr(other, field_name) - return CacheConfig(**merged_kwargs) - - class SequenceInfo: """An interface to hold information about how the sequence is laid out and stored in cache. 
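With the pydantic `CacheConfig` removed, attention backends now read their dtype and sizing settings directly from the LLM API's `KvCacheConfig` (imported elsewhere in this diff from `tensorrt_llm.llmapi.llm_args`). A minimal sketch of such a config, limited to the fields this diff actually consumes (`free_gpu_memory_fraction`, `dtype`, `mamba_ssm_cache_dtype`):

```python
from tensorrt_llm.llmapi.llm_args import KvCacheConfig

# "auto" means: fall back to the dtype of the tensors feeding the cache
# (see AttentionDescriptor.resolve_cache_dtype later in this diff).
cache_config = KvCacheConfig(
    free_gpu_memory_fraction=0.9,     # fraction of free GPU memory reserved for kv-caches
    dtype="auto",                     # KV-cache dtype
    mamba_ssm_cache_dtype="float32",  # Mamba SSM state dtype; float32 trades speed for accuracy
)
```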
@@ -398,9 +362,9 @@ class SequenceInfo: def __init__( self, - max_seq_len: int = 1, - max_batch_size: int = 1, - page_size: int = 0, + max_seq_len: int, + max_batch_size: int, + tokens_per_block: Optional[int] = None, max_num_tokens: Optional[int] = None, vocab_size_padded: Optional[int] = None, ): @@ -411,9 +375,7 @@ class SequenceInfo: includes the tokens in the input sequence and the tokens generated by the model. max_batch_size: corresponds to the maximum number of sequences (or requests) that the model can process. - page_size: corresponds to the page size of the cache. For an unpaged cache, the page - size should be set to max_seq_len. Also note that two sequences in a batch can not - share a page. + tokens_per_block: corresponds to the tokens per block of the cache. max_num_tokens: corresponds to the maximum number of tokens that the model can process across all sequences in the batch. If a batch is composed of context-only requests of input sequence length ISL, then the maximum number of sequences possible in the @@ -427,46 +389,27 @@ class SequenceInfo: # set up basic attributes self.max_seq_len = max_seq_len self.max_batch_size = max_batch_size - self.page_size = page_size if page_size > 0 else max_seq_len - self.vocab_size_padded = vocab_size_padded - # NOTE (lucaslie): WAR to address issue when using flashinfer attention with + self.tokens_per_block = tokens_per_block or max_seq_len + # NOTE (lucaslie): +1 is a WAR to address issue when using flashinfer attention with # (max_batch_size, max_seq_len) input in trtllm runtime. # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504 - max_seq_len_adjusted = self.max_seq_len + 1 + self.max_num_tokens = max_num_tokens or (max_seq_len + 1) * max_batch_size - # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9883 clean up this hack - self.max_state_slots = max_batch_size + 1 + # TODO (lucaslie): can we remove this eventually from this i/f? + self.vocab_size_padded = vocab_size_padded - # if the provided max_num_tokens is less than the max_batch_size * max_seq_len_adjusted, - # we use the provided max_num_tokens. If max_num_tokens provided is more, we still use - # max_batch_size * max_seq_len_adjusted since the extra tokens cannot be used. - self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted - if max_num_tokens is not None and max_num_tokens > 0: - self.max_num_tokens = min(self.max_num_tokens, max_num_tokens) - - # Num pages can not be less than max_batch_size. - self._num_pages = max( - self.max_batch_size, - (self.max_num_tokens) // self.page_size # floored number of pages - + (self.max_num_tokens / self.max_batch_size % self.page_size > 0) # check for overflow - * self.max_batch_size, # +1 page per sequence if overflow is required - ) - # sanity check - assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size" - - # cache_loc requires some special treatment due to block reuse. Note that the constraint for - # cache_loc with block_reuse is as follows: - # 0 <= cache_loc < num_pages - # len(cache_loc) <= max_num_cache_loc_assignments - max_num_cache_loc_assignments = ( - max_seq_len_adjusted // self.page_size + 1 - ) * self.max_batch_size + # NOTE: we keep an extra state slot around to simplify cuda graph padding + # WHY? + # Requests that just finished won't free their used resources immediately. Specifically, the + # running order is self.scheduler.schedule_request, self._forward_step() and + # self._process_previous_batch() in the PyExecutor. 
Hence, the current forward step will + # remove finished requests but will not remove mamba_cache immediately and therefore it + # won't be available in time for padding in the next forward step. + self.max_num_state_slots = max_batch_size + 1 # log parameters ad_logger.info( - f"[SequenceInfo:] {self.max_seq_len=}, {self.max_batch_size=}, {self.page_size=}, " - f"{self.max_num_tokens=} (inferred), {max_num_tokens=} (provided), {self.num_pages=}, " - f"{max_num_cache_loc_assignments=}" + f"[SequenceInfo:] {self.max_seq_len=}, {self.max_batch_size=}, {self.max_num_tokens=}" ) # indicator if extra args are activated that are needed for cached attention backends @@ -496,8 +439,10 @@ class SequenceInfo: # OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER ("_gather_idx", self.max_num_tokens, torch.int), ("_mask_scatter_indices", self.max_num_tokens, torch.int), - # cache_loc is LAST for truncation optimization (it's the largest tensor) - ("cache_loc", max_num_cache_loc_assignments, torch.int), + # cache_loc is LAST for truncation optimization (it can be the largest tensor) + # NOTE: sufficient for max_num_tokens forward pass. will be resized when KVCacheManager + # is created. + ("cache_loc", self.max_num_tokens, torch.int), ] # Create the InputBuffer that manages contiguous host and device memory @@ -619,35 +564,44 @@ class SequenceInfo: def is_generate(self) -> bool: return all(sl == 1 for sl in self.seq_len) - @property - def num_pages(self) -> int: - return self._num_pages - - @num_pages.setter - def num_pages(self, value): - self._num_pages = value - # Check if we need to resize cache_loc (it's the last tensor in the buffer) - cache_loc_capacity = self._input_buffer.get_capacity("cache_loc") - if value > cache_loc_capacity: - ad_logger.info( - f"Resizing cache_loc capacity from {cache_loc_capacity} to {value} " - f"to accommodate num_pages={value}" - ) - # Resize the input buffer (cache_loc is the last tensor, so this is supported) - self._input_buffer.resize("cache_loc", value) - # Also resize the args_list to match - old_size = len(self._args_list["cache_loc"]) - self._args_list["cache_loc"].extend([0] * (value - old_size)) - - @property - def is_paged(self) -> bool: - return self.page_size < self.max_seq_len - @property def page_assignments(self) -> List[List[int]]: """Return the page assignments for each sequence.""" return self._get_page_assignments(self.cache_loc, self.pages_per_seq) + def estimate_cache_tokens_per_forward(self) -> int: + """Estimate the max number of tokens that will be cached for a forward pass. + + It is estimated assuming a worst-case allocation of tokens across sequences in a batch. + """ + seq_len = math.ceil(self.max_num_tokens / self.max_batch_size) + num_blocks_estimate_per_seq = math.ceil(seq_len / self.tokens_per_block) + num_blocks_estimate = num_blocks_estimate_per_seq * self.max_batch_size + return num_blocks_estimate * self.tokens_per_block + + def estimate_cache_loc_capacity(self, num_blocks: int) -> None: + """Estimate needed capacity of cache_loc based on available blocks and resize.""" + cache_loc_capacity = self._input_buffer.get_capacity("cache_loc") + + # cache_loc requires some special treatment due to block reuse. Note that the constraint for + # cache_loc with block_reuse is as follows: + # 0 <= cache_loc < num_blocks + # len(cache_loc) <= num_blocks * max_batch_size + # NOTE: # However, num_blocks * max_batch_size may be a potentially very large number and an + # overestimation, see we assume at most 10x reuse of blocks. 
+ estimated_capacity = num_blocks * min(10, self.max_batch_size) + + # NOTE (lucaslie): WAR to address issue when using flashinfer attention with + # (max_batch_size, max_seq_len) input in trtllm runtime. + # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504 + estimated_capacity = estimated_capacity + 1 + + if estimated_capacity > cache_loc_capacity: + self._input_buffer.resize("cache_loc", estimated_capacity) + # Also resize the args_list to match + old_size = len(self._args_list["cache_loc"]) + self._args_list["cache_loc"].extend([0] * (estimated_capacity - old_size)) + @staticmethod def _get_page_assignments( cache_locations: List[int], pages_per_sequence: List[int] @@ -732,7 +686,8 @@ class SequenceInfo: # figure out page assignments pages_per_seq = [ - len(ids_one_seq) // self.page_size + (len(ids_one_seq) % self.page_size > 0) + len(ids_one_seq) // self.tokens_per_block + + (len(ids_one_seq) % self.tokens_per_block > 0) for ids_one_seq in input_ids ] cache_loc = list(range(sum(pages_per_seq))) @@ -889,7 +844,7 @@ class SequenceInfo: This i/f will ensure that all sequence info args are updated accordingly. Reset values are chosen as "neutral" values so that for cases like rounding up batch sizes for cudagraph we - only write to unused buffers/caches. + only write to unused caches. """ ### UPDATE SEQUENCE LENGTH AND INPUT POSITION FIRST SINCE IT'S USED FOR OTHER UPDATES ###### if seq_len is None: @@ -953,7 +908,7 @@ class SequenceInfo: # update last page length if last_page_len is None: - last_page_len = [(slwc - 1) % self.page_size + 1 for slwc in seq_len_with_cache] + last_page_len = [(slwc - 1) % self.tokens_per_block + 1 for slwc in seq_len_with_cache] self._store_arg("last_page_len", last_page_len) # check for updated slot_idx @@ -1049,6 +1004,90 @@ class SequenceInfo: host_function(**{arg: self._get_arg(arg) for arg in args}) +class ResourceHandler(ABC): + """An abstract interface to handle a generic resource needed by attention operators. + + The ResourceHandler interface standardizes operations that the cached sequence interface + performs on the resources providing an abstract handle. + """ + + @abstractmethod + def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor: + """Initialize the resource for the given sequence info.""" + + +class ManagedResourceHandler(ResourceHandler): + """An abstract interface to handle a resource that is managed by the cache manager.""" + + @final + def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor: + """Allocate the resource for the given sequence info.""" + raise NotImplementedError("Managed resources should not be allocated directly!") + + +class PagedResourceHandler(ManagedResourceHandler): + """An abstract interface to handle a paged resource. + + The PagedResourceHandler can be used to handle resources that support paging such as kv-caches. + """ + + def __init__(self, *token_shape: int, dtype: torch.dtype) -> None: + """Initialize the PagedResourceHandler. + + Args: + page_shape: The shape of a single page of the resource. + dtype: The dtype of the resource. + """ + self.token_shape = token_shape + self.dtype = dtype + + +class StateResourceHandler(ManagedResourceHandler): + """Handler for per-sequence state resources (e.g., Mamba SSM/conv states). + + These resources have shape [max_batch_size, *state_shape] and are + managed by MambaHybridCacheManager via byte-level pooling. + """ + + def __init__(self, *state_shape: int, dtype: torch.dtype) -> None: + """Initialize the StateResourceHandler. 
+ + Args: + state_shape: The shape of a single state resource. + dtype: The dtype of the state resource. + """ + self.state_shape = state_shape + self.dtype = dtype + + +class UnpagedResourceHandler(ResourceHandler): + """Handler for per-token unpaged resources (e.g., unpaged KV caches). + + These resources have shape [max_batch_size, max_seq_len, *token_shape]. + They are allocated locally and not managed by MambaHybridCacheManager. + """ + + def __init__(self, *token_shape: int, dtype: torch.dtype) -> None: + """Initialize the UnpagedResourceHandler. + + Args: + token_shape: The shape of the resource per token. + dtype: The dtype of the resource. + """ + self.token_shape = token_shape + self.dtype = dtype + + def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor: + """Initialize the unpaged resource for the given sequence info.""" + return torch.empty( + sequence_info.max_num_state_slots, + sequence_info.max_seq_len, + *self.token_shape, + device=sequence_info.device, + dtype=self.dtype, + ) + + class MHACallable(Protocol): def __call__( self, @@ -1062,18 +1101,10 @@ class PrepareMetadataCallable(Protocol): ) -> List[torch.Tensor]: ... -class GetCacheCallable(Protocol): - def __call__(self, sequence_info: SequenceInfo) -> torch.Tensor: ... - - -class GetBufferCallable(GetCacheCallable): - pass - - -CacheInitializerDict = Dict[str, GetCacheCallable] -BufferInitializerDict = Dict[str, GetBufferCallable] AttentionLayout = Literal["bsnd", "bnsd"] +ResourceHandlerDict = Dict[str, ResourceHandler] + class AttentionDescriptor(ABC): """An interface to define a functional attention operator. @@ -1083,11 +1114,6 @@ class AttentionDescriptor(ABC): specific to the attention op. """ - @classmethod - @abstractmethod - def is_paged(cls) -> bool: - """Return if the attention op is paged or not.""" - @classmethod @abstractmethod def get_attention_layout(cls) -> AttentionLayout: @@ -1116,7 +1142,6 @@ class AttentionDescriptor(ABC): *meta_std, # standard metadata fields identified by matching arg names! *meta_extra,# metadata about the sequences as returned by the prepare_metadata op *caches, # contains layer-specific caches per provided cache initializers - *buffers, # global buffers used by the attention op as provided by buffer initializers *constants, # basic arguments (int, float, str, None) added as CONSTANTS in the graph ) -> torch.Tensor: ... ``` @@ -1172,52 +1197,46 @@ class AttentionDescriptor(ABC): @classmethod @abstractmethod def get_cache_initializers( - cls, source_attn_node: Node, cache_config: CacheConfig - ) -> CacheInitializerDict: - """Provide a dictionary of function pointers that can be used to initialize the caches. + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: + """Provide a dictionary of resource handlers that can be used to initialize the resources. The key corresponds to the argument name used in the attention op signature. The function key doesn't need to be unique across multiple attention nodes in the graph. The key used to describe the cache in the graph will be patched with the attention node index to ensure uniqueness. - ``get_cache_initializers`` will be called *once* during cache initialization and before - the initial forward pass for each attention op detected in the graph. The caches will be - managed by the global CacheManager and passed back to the attention op during the forward - pass. 
+ The resource will be initialized before the initial forward pass and will be managed by the + global CacheManager and passed back to the model during the forward pass. If the cache initializer requires information about the attention op, it can retrieve the necessary information from the source attention node and cache config. """ - @classmethod - def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict: - """Provide a dictionary of function pointers that can be used to initialize buffers. - - The key corresponds to the buffer name used in the graph module and will **not** - be patched unlike a cache key. Hence, it is a **global** key that is shared across all - attention ops in the model much like a regular buffer in an nn.Module. That means if this - i/f is called for multiple attention ops, the same buffer will be shared across all of them - if this function provides the same key multiple times. - - Buffers are initialize *once* after the model initialization and before the initial forward - pass for each attention op detected in the graph. The buffer will be managed by the global - CacheManager and passed back to the attention op during the forward pass. - - If the buffer initializer requires information about the attention op, it can retrieve - the necessary information from the source attention node. - """ - return {} - @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: """Provide a list of constant arguments to be passed to the attention op. The constant arguments are passed to the attention op as additional arguments after the - caches and buffers. The constants are expected to be of type int, float, str, or None. + caches. The constants are expected to be of type int, float, str, or None. """ return [] + @staticmethod + def resolve_cache_dtype(dtype_config: str, fallback_dtype: torch.dtype) -> torch.dtype: + """Resolve cache dtype from KvCacheConfig dtype string to torch.dtype. + + Args: + dtype_config: The dtype string from KvCacheConfig (e.g., "auto", "float16", "bfloat16"). + fallback_dtype: The fallback dtype to use when dtype_config is "auto". + + Returns: + The resolved torch.dtype. + """ + if dtype_config == "auto": + return fallback_dtype + return str_dtype_to_torch(dtype_config) + @classmethod def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]: """Get function that performs host-side prep for the forward pass for the attention op. 
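To illustrate the reworked interface, here is a sketch of what a minimal backend could look like after this change: `get_cache_initializers` returns declarative `ResourceHandler`s instead of allocation callbacks, and dtype selection goes through `resolve_cache_dtype`. Only the cache-initializer hook is shown (a real descriptor also implements the other abstract methods), and the class and cache names are illustrative.

```python
# Sketch only: a hypothetical backend following the new ResourceHandler-based interface,
# using the same imports as the real backends updated in this diff.
import torch
from torch.fx import Node

from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import (
    AttentionDescriptor,
    PagedResourceHandler,
    ResourceHandlerDict,
)


class MyPagedAttention(AttentionDescriptor):
    @classmethod
    def get_cache_initializers(
        cls, source_attn_node: Node, cache_config: KvCacheConfig
    ) -> ResourceHandlerDict:
        # shapes come from the fake tensor of the source attention node ([b, s, n, d] layout)
        k_fake = source_attn_node.args[1].meta["val"]
        num_kv_heads, head_dim = k_fake.shape[2], k_fake.shape[3]
        # "auto" resolves to the key tensor's dtype; any other string is parsed as a torch dtype
        cache_dtype = cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype)
        return {
            "k_cache": PagedResourceHandler(num_kv_heads, head_dim, dtype=cache_dtype),
            "v_cache": PagedResourceHandler(num_kv_heads, head_dim, dtype=cache_dtype),
        }
```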
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py index 757aff042f..5c52410f24 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py @@ -11,17 +11,16 @@ import torch from torch._ops import OpOverloadPacket from torch.fx import Node +from .....llmapi.llm_args import KvCacheConfig from ...utils.node_utils import extract_op_args from ..attention_interface import ( AttentionDescriptor, AttentionLayout, AttentionRegistry, - BufferInitializerDict, - CacheConfig, - CacheInitializerDict, Constant, MHACallable, - SequenceInfo, + ResourceHandlerDict, + StateResourceHandler, ) from .delta_rule.chunk import chunk_delta_rule_fwd from .delta_rule.fused_recurrent import fused_recurrent_delta_rule_fwd @@ -136,11 +135,6 @@ def fla_cached_delta_rule_fake( @AttentionRegistry.register("fla_delta") class FlaDeltaBackend(AttentionDescriptor): - @classmethod - def is_paged(cls) -> bool: - # TODO: we should refine our notion of "is_paged" --> seems counterintuitive for ssm nows - return True - @classmethod def get_attention_layout(cls) -> AttentionLayout: return "bsnd" @@ -164,8 +158,8 @@ class FlaDeltaBackend(AttentionDescriptor): @classmethod def get_cache_initializers( - cls, source_attn_node: Node, cache_config: CacheConfig - ) -> CacheInitializerDict: + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: key_node = source_attn_node.args[1] value_node = source_attn_node.args[2] num_heads = key_node.meta["val"].shape[-2] @@ -173,21 +167,15 @@ class FlaDeltaBackend(AttentionDescriptor): value_dim = value_node.meta["val"].shape[-1] key_dtype = key_node.meta["val"].dtype - def _get_delta_cache(si: SequenceInfo): - return torch.empty( - si.max_state_slots, + return { + "delta_cache": StateResourceHandler( num_heads, key_dim, value_dim, - device=si.device, - dtype=cache_config.delta_dtype or key_dtype, + # NOTE: not configurable at the moment, using auto to match the key dtype + dtype=cls.resolve_cache_dtype("auto", key_dtype), ) - - return {"delta_cache": _get_delta_cache} - - @classmethod - def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict: - return {} + } @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index ac530fd7ea..7d43dc6cea 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, fields -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Literal, Optional, Tuple, Union import flashinfer import torch @@ -7,6 +7,7 @@ from torch._ops import OpOverloadPacket from torch._subclasses import FakeTensor from torch.fx import Node +from ....llmapi.llm_args import KvCacheConfig from ...flashinfer_utils import get_env_enable_pdl from ..utils.cuda_graph import cuda_graph_state from ..utils.logger import ad_logger @@ -15,14 +16,12 @@ from .attention_interface import ( AttentionDescriptor, AttentionLayout, AttentionRegistry, - BufferInitializerDict, - CacheConfig, - CacheInitializerDict, Constant, MHACallable, + PagedResourceHandler, PrepareMetadataCallable, PrepareMetadataHostCallable, - 
SequenceInfo, + ResourceHandlerDict, ) @@ -57,6 +56,9 @@ class _FlashInferPlanner: ] plan_params_prefill: Optional[PlanParams] plan_params_decode: Optional[PlanParams] + kv_layout: Literal["NHD", "HND"] = ( + "NHD" # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/10966 + ) def __init__(self): self.workspace_buffer = None @@ -77,7 +79,7 @@ class _FlashInferPlanner: if use_cuda_graph: return flashinfer.BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, - "NHD", + self.kv_layout, use_cuda_graph=True, paged_kv_indptr_buffer=indptr, paged_kv_indices_buffer=indices, @@ -88,28 +90,33 @@ class _FlashInferPlanner: else: return flashinfer.BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, - "NHD", + self.kv_layout, use_tensor_cores=True, backend="fa2" if torch.cuda.get_device_capability(0) == (9, 0) else "auto", ) - def init_workspace(self, workspace_buffer: torch.Tensor): + def reset(self, device: torch.device) -> None: + self.plan_params_prefill = None + self.plan_params_decode = None + + if isinstance(self.workspace_buffer, torch.Tensor): + return + self.__init__() # reset all state - self.workspace_buffer = workspace_buffer + # NOTE (lucaslie): avoid OOM for many cudagraphs, + # see https://github.com/NVIDIA/TensorRT-LLM/pull/3686 + self.workspace_buffer = torch.empty(320 * 1024 * 1024, device=device, dtype=torch.uint8) + # NOTE (lucaslie): flashinfer fa3 backend has accuracy issue + illegal memory access issues # on H100 PCIe, see https://github.com/NVIDIA/TensorRT-LLM/issues/4504 self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( self.workspace_buffer, - "NHD", + self.kv_layout, backend="fa2", ) self.decode_wrapper = self._init_decode_wrapper() - def reset(self) -> None: - self.plan_params_prefill = None - self.plan_params_decode = None - def plan_generate_only( self, num_seq: int, @@ -248,7 +255,7 @@ def prepare_flashinfer_metadata( num_seq = num_prefill + num_decode num_tokens = num_prefill_tokens + num_decode - _GlobalFlashInferPlanner.reset() + _GlobalFlashInferPlanner.reset(position_ids.device) qo_indptr = cu_seqlen[: num_seq + 1] @@ -316,8 +323,6 @@ def flashinfer_mha_with_cache( # CACHES k_cache: torch.Tensor, v_cache: torch.Tensor, - # BUFFERS - workspace_buffer: torch.Tensor, # CONSTANTS scale: Optional[float], k_scale: float, @@ -458,8 +463,6 @@ def flashinfer_mha_with_cache_fake( # CACHES k_cache: torch.Tensor, v_cache: torch.Tensor, - # BUFFERS - workspace_buffer: torch.Tensor, # CONSTANTS scale: Optional[float], k_scale: float, @@ -474,11 +477,6 @@ class FlashInferAttention(AttentionDescriptor): def _get_planner(cls) -> _FlashInferPlanner: return _GlobalFlashInferPlanner - @classmethod - def is_paged(cls): - """Return if the attention op is paged or not.""" - return True - @classmethod def get_attention_layout(cls) -> AttentionLayout: """Get the attention layout expected by the backend.""" @@ -519,35 +517,25 @@ class FlashInferAttention(AttentionDescriptor): @classmethod def get_cache_initializers( - cls, source_attn_node: Node, cache_config: CacheConfig - ) -> CacheInitializerDict: + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: # source op is [bsnd] layout already k_fake: FakeTensor = source_attn_node.args[1].meta["val"] num_kv_heads = k_fake.shape[2] head_dim = k_fake.shape[3] - def _get_cache(si: SequenceInfo): - return torch.empty( - si.num_pages, - si.page_size, + return { + "k_cache": PagedResourceHandler( num_kv_heads, head_dim, - device=si.device, - dtype=cache_config.dtype or k_fake.dtype, - 
) - - return {"k_cache": _get_cache, "v_cache": _get_cache} - - @classmethod - def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict: - def _init_workspace(si: SequenceInfo) -> torch.Tensor: - # NOTE (lucaslie): avoid OOM for many cudagraphs, - # see https://github.com/NVIDIA/TensorRT-LLM/pull/3686 - buffer = torch.empty(320 * 1024 * 1024, dtype=torch.uint8, device=si.device) - cls._get_planner().init_workspace(buffer) - return buffer - - return {"workspace_buffer": _init_workspace} + dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype), + ), + "v_cache": PagedResourceHandler( + num_kv_heads, + head_dim, + dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype), + ), + } @classmethod def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py index bc7752df52..800419efdc 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py @@ -33,16 +33,16 @@ from torch.fx import Node from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID from tensorrt_llm._torch.modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from .....llmapi.llm_args import KvCacheConfig from ...utils.node_utils import extract_op_args from ..attention_interface import ( AttentionDescriptor, AttentionLayout, AttentionRegistry, - CacheConfig, - CacheInitializerDict, Constant, MHACallable, - SequenceInfo, + ResourceHandlerDict, + StateResourceHandler, ) @@ -164,10 +164,6 @@ def cuda_cached_causal_conv1d_wrapper(input: torch.Tensor, *args, **kwargs) -> t @AttentionRegistry.register("cuda_causal_conv") class CudaBackendCausalConv(AttentionDescriptor): - @classmethod - def is_paged(cls) -> bool: - return True - @classmethod def get_attention_layout(cls) -> AttentionLayout: # Hidden states follow [b, s, c] @@ -193,24 +189,21 @@ class CudaBackendCausalConv(AttentionDescriptor): @classmethod def get_cache_initializers( - cls, source_attn_node: Node, cache_config: CacheConfig - ) -> CacheInitializerDict: + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: inp_fake: torch.Tensor = source_attn_node.args[0].meta["val"] w_fake: torch.Tensor = source_attn_node.args[1].meta["val"] in_channels = inp_fake.shape[-1] kernel_size = w_fake.shape[-1] - def _get_conv_cache(si: SequenceInfo): - return torch.empty( - si.max_state_slots, - in_channels, - max(1, kernel_size - 1), - device=si.device, - dtype=inp_fake.dtype, - ) - - return {"conv_state_cache": _get_conv_cache} + conv_state_handler = StateResourceHandler( + in_channels, + max(1, kernel_size - 1), + # NOTE: not configurable at the moment, using auto to match the input dtype + dtype=cls.resolve_cache_dtype("auto", inp_fake.dtype), + ) + return {"conv_state_cache": conv_state_handler} @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py index fd5c717023..16d91bef7f 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py @@ -16,17 +16,16 @@ import torch.nn.functional as 
F from torch._ops import OpOverloadPacket from torch.fx import Node +from .....llmapi.llm_args import KvCacheConfig from ...utils.node_utils import extract_op_args from ..attention_interface import ( AttentionDescriptor, AttentionLayout, AttentionRegistry, - BufferInitializerDict, - CacheConfig, - CacheInitializerDict, Constant, MHACallable, - SequenceInfo, + ResourceHandlerDict, + StateResourceHandler, ) @@ -270,11 +269,6 @@ def _torch_cached_causal_conv1d_fake( @AttentionRegistry.register("torch_causal_conv") class TorchBackendCausalConv(AttentionDescriptor): - @classmethod - def is_paged(cls) -> bool: - # TODO: we should refine our notion of "is_paged" --> seems counterintuitive for ssm nows - return True - @classmethod def get_attention_layout(cls) -> AttentionLayout: # Hidden states follow [b, s, c] @@ -300,28 +294,22 @@ class TorchBackendCausalConv(AttentionDescriptor): @classmethod def get_cache_initializers( - cls, source_attn_node: Node, cache_config: CacheConfig - ) -> CacheInitializerDict: + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: inp_fake: torch.Tensor = source_attn_node.args[0].meta["val"] w_fake: torch.Tensor = source_attn_node.args[1].meta["val"] in_channels = inp_fake.shape[-1] kernel_size = w_fake.shape[-1] - def _get_conv_cache(si: SequenceInfo): - return torch.empty( - si.max_state_slots, + return { + "conv_state_cache": StateResourceHandler( in_channels, kernel_size, - device=si.device, - dtype=inp_fake.dtype, + # NOTE: not configurable at the moment, using auto to match the input dtype + dtype=cls.resolve_cache_dtype("auto", inp_fake.dtype), ) - - return {"conv_state_cache": _get_conv_cache} - - @classmethod - def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict: - return {} + } @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py index 03f403e0f6..96995d5a74 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py @@ -12,17 +12,16 @@ import torch from torch._ops import OpOverloadPacket from torch.fx import Node +from .....llmapi.llm_args import KvCacheConfig from ...utils.node_utils import extract_op_args from ..attention_interface import ( AttentionDescriptor, AttentionLayout, AttentionRegistry, - BufferInitializerDict, - CacheConfig, - CacheInitializerDict, Constant, MHACallable, - SequenceInfo, + ResourceHandlerDict, + StateResourceHandler, ) from .torch_mamba import _torch_ssm_prefill @@ -268,11 +267,6 @@ def _torch_cached_ssm_fake( @AttentionRegistry.register("torch_ssm") class TorchBackendSSM(AttentionDescriptor): - @classmethod - def is_paged(cls) -> bool: - # TODO: we should refine our notion of "is_paged" --> seems counterintuitive for ssm now - return True - @classmethod def get_attention_layout(cls) -> AttentionLayout: # Hidden states follow [b, s, n, d] @@ -297,8 +291,8 @@ class TorchBackendSSM(AttentionDescriptor): @classmethod def get_cache_initializers( - cls, source_attn_node: Node, cache_config: CacheConfig - ) -> CacheInitializerDict: + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: # Shapes from fake tensors hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"] B_fake: torch.Tensor = source_attn_node.args[2].meta["val"] @@ -315,23 +309,13 
@@ class TorchBackendSSM(AttentionDescriptor): ssm_state_size = max(1, B_fake.shape[-1]) # extract ssm_state_dtype from cache_config or hs_fake - ssm_state_dtype = cache_config.mamba_dtype or hs_fake.dtype + ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype) - def _get_ssm_cache(si: SequenceInfo): - return torch.empty( - si.max_state_slots, - num_heads, - head_dim, - ssm_state_size, - device=si.device, - dtype=ssm_state_dtype, + return { + "ssm_state_cache": StateResourceHandler( + num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype ) - - return {"ssm_state_cache": _get_ssm_cache} - - @classmethod - def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict: - return {} + } @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py index 6df74b3c2a..5f6ecdee9a 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py @@ -24,17 +24,17 @@ from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chun from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined +from .....llmapi.llm_args import KvCacheConfig from ...utils.node_utils import extract_op_args from ..attention_interface import ( AttentionDescriptor, AttentionLayout, AttentionRegistry, - CacheConfig, - CacheInitializerDict, Constant, MHACallable, PrepareMetadataCallable, - SequenceInfo, + ResourceHandlerDict, + StateResourceHandler, ) @@ -274,10 +274,6 @@ def _triton_cached_ssm_fake( @AttentionRegistry.register("triton_ssm") class TritonBackendSSM(AttentionDescriptor): - @classmethod - def is_paged(cls) -> bool: - return True - @classmethod def get_attention_layout(cls) -> AttentionLayout: # Hidden states follow [b, s, n, d] @@ -313,8 +309,8 @@ class TritonBackendSSM(AttentionDescriptor): @classmethod def get_cache_initializers( - cls, source_attn_node: Node, cache_config: CacheConfig - ) -> CacheInitializerDict: + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: # Shapes from fake tensors hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"] B_fake: torch.Tensor = source_attn_node.args[2].meta["val"] @@ -328,19 +324,13 @@ class TritonBackendSSM(AttentionDescriptor): ssm_state_size = max(1, B_fake.shape[-1]) # extract ssm_state_dtype from cache_config or hs_fake - ssm_state_dtype = cache_config.mamba_dtype or hs_fake.dtype + ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype) - def _get_ssm_cache(si: SequenceInfo): - return torch.empty( - si.max_state_slots, - num_heads, - head_dim, - ssm_state_size, - device=si.device, - dtype=ssm_state_dtype, + return { + "ssm_state_cache": StateResourceHandler( + num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype ) - - return {"ssm_state_cache": _get_ssm_cache} + } @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py index 716cda7d1b..350be81d58 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py @@ -6,21 
+6,35 @@ import torch from torch._ops import OpOverloadPacket from torch.fx import Node +from ....llmapi.llm_args import KvCacheConfig from .attention_interface import ( AttentionDescriptor, AttentionLayout, AttentionRegistry, - BufferInitializerDict, - CacheConfig, - CacheInitializerDict, MHACallable, - SequenceInfo, + ResourceHandlerDict, + UnpagedResourceHandler, ) from .triton_attention import _flattened_context_mha, _generate_mha Constant = Union[int, float, str, None] +def _precompute_inv_freq( + max_seq_len: int, head_dim: int, rope_theta: float, device: torch.device +) -> torch.Tensor: + inv_freq = 1.0 / ( + rope_theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim) + ) + t = torch.arange(max_seq_len, device=inv_freq.device, dtype=inv_freq.dtype) + + freqs = torch.outer(t, inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + cos_sin_stacked = torch.stack([emb.cos().to(torch.bfloat16), emb.sin().to(torch.bfloat16)]) + return cos_sin_stacked + + @torch.library.custom_op( "auto_deploy::triton_attention_fused_flattened_mla_with_cache", mutates_args=() ) @@ -41,10 +55,9 @@ def fused_flattened_mla_with_cache( # CACHES k_cache: torch.Tensor, v_cache: torch.Tensor, - # BUFFERS - cos_sin_stacked: torch.Tensor, # CONSTANTS softmax_scale: Optional[float] = None, + rope_theta: Optional[float] = None, ) -> torch.Tensor: """Flattened & fused MLA with cache with triton kernels.""" # b, s info @@ -84,7 +97,12 @@ def fused_flattened_mla_with_cache( k_pe = k_pe.clone().transpose(1, 2).view(*bs_view, -1, qk_rope_head_dim).contiguous() value_states = value_states.transpose(1, 2).view(*bs_view, -1, v_head_dim).contiguous() # Apply RoPE - if cos_sin_stacked.numel() > 0: + if rope_theta is not None: + max_seq_len = (input_pos + seq_len).max().item() + cos_sin_stacked = _precompute_inv_freq( + max_seq_len, qk_rope_head_dim, rope_theta, q_pe.device + ) + # Extract cos and sin from freqs_cis cos_base = cos_sin_stacked[0, ...] sin_base = cos_sin_stacked[1, ...] 
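The fused MLA op now derives its RoPE tables on the fly from `rope_theta` instead of reading a precomputed global buffer. A self-contained sketch that mirrors the math of `_precompute_inv_freq` above and checks the resulting table shape (`[2, max_seq_len, head_dim]`); the standalone function name is illustrative:

```python
import torch


def precompute_cos_sin(
    max_seq_len: int, head_dim: int, rope_theta: float, device: torch.device
) -> torch.Tensor:
    # Mirrors _precompute_inv_freq from this diff: build a [2, max_seq_len, head_dim]
    # table of cos/sin values for rotary position embeddings.
    inv_freq = 1.0 / (
        rope_theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)
    )
    t = torch.arange(max_seq_len, device=device, dtype=inv_freq.dtype)
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)  # duplicate to cover the full rope head dim
    return torch.stack([emb.cos(), emb.sin()])


table = precompute_cos_sin(max_seq_len=16, head_dim=64, rope_theta=10000.0, device=torch.device("cpu"))
assert table.shape == (2, 16, 64)  # [cos/sin, positions, rope head dim]
```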
@@ -176,10 +194,9 @@ def fused_flattened_mla_with_cache_fake( # CACHES k_cache: torch.Tensor, v_cache: torch.Tensor, - # BUFFERS - cos_sin_stacked: torch.Tensor, # CONSTANTS softmax_scale: Optional[float] = None, + rope_theta: Optional[float] = None, ): v_head_dim = kv.shape[-1] - q_nope.shape[-1] return torch.empty_like(kv[..., -v_head_dim:]) @@ -187,11 +204,6 @@ def fused_flattened_mla_with_cache_fake( @AttentionRegistry.register("MultiHeadLatentAttention") class MultiHeadLatentAttention(AttentionDescriptor): - @classmethod - def is_paged(cls) -> bool: - """Return if the attention op is paged or not.""" - return False - @classmethod def get_attention_layout(cls) -> AttentionLayout: """Get the attention layout expected by the backend.""" @@ -216,8 +228,8 @@ class MultiHeadLatentAttention(AttentionDescriptor): @classmethod def get_cache_initializers( - cls, source_attn_node: Node, cache_config: CacheConfig - ) -> CacheInitializerDict: + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: q_nope_fake = source_attn_node.args[0].meta["val"] q_pe_fake = source_attn_node.args[1].meta["val"] kv_fake = source_attn_node.args[2].meta["val"] @@ -226,56 +238,21 @@ class MultiHeadLatentAttention(AttentionDescriptor): head_dim = q_nope_fake.shape[-1] rope_dim = q_pe_fake.shape[-1] - def _get_k_cache(si: SequenceInfo): - assert not si.is_paged, "Paged cache not supported for MultiHeadLatentAttention" - return torch.empty( - si.num_pages, - si.page_size, + return { + "k_cache": UnpagedResourceHandler( num_kv_heads, head_dim + rope_dim, - device=si.device, - dtype=cache_config.dtype or kv_fake.dtype, - ) - - def _get_v_cache(si: SequenceInfo): - assert not si.is_paged, "Paged cache not supported for MultiHeadLatentAttention" - return torch.empty( - si.num_pages, - si.page_size, + dtype=cls.resolve_cache_dtype(cache_config.dtype, kv_fake.dtype), + ), + "v_cache": UnpagedResourceHandler( num_kv_heads, head_dim, - device=si.device, - dtype=cache_config.dtype or kv_fake.dtype, - ) - - return {"k_cache": _get_k_cache, "v_cache": _get_v_cache} - - @classmethod - def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict: - q_pe_fake = source_attn_node.args[1].meta["val"] - rope_head_dim = q_pe_fake.shape[-1] - rope_theta: float = 10000.0 # TODO: remove once MLA is unfused - - def _get_cos_sin_stacked(si: SequenceInfo): - if rope_theta is None: - return torch.empty(0, device=si.device) - return cls._precompute_inv_freq(si.max_seq_len, rope_head_dim, rope_theta).to(si.device) - - return { - f"cos_sin_stacked_{rope_head_dim}_{rope_theta}".replace(".", "_"): _get_cos_sin_stacked + dtype=cls.resolve_cache_dtype(cache_config.dtype, kv_fake.dtype), + ), } @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: - return [None] - - @staticmethod - def _precompute_inv_freq(seq_len: int, head_dim: int, rope_theta: float = 1e4) -> torch.Tensor: - inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim)) - t = torch.arange(seq_len, device=inv_freq.device, dtype=inv_freq.dtype) - - freqs = torch.outer(t, inv_freq.to(t.device)) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - cos_sin_stacked = torch.stack([emb.cos().to(torch.bfloat16), emb.sin().to(torch.bfloat16)]) - return cos_sin_stacked + softmax_scale = None + rope_theta = 10000.0 # TODO: remove once MLA is unfused + return [softmax_scale, rope_theta] diff --git 
a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py index 09bc253708..378ca2639e 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py @@ -8,18 +8,17 @@ from torch._ops import OpOverloadPacket from torch._subclasses import FakeTensor from torch.fx import Node +from ....llmapi.llm_args import KvCacheConfig from ..utils.logger import ad_logger from ..utils.node_utils import extract_op_args from .attention_interface import ( AttentionDescriptor, AttentionLayout, AttentionRegistry, - BufferInitializerDict, - CacheConfig, - CacheInitializerDict, Constant, MHACallable, - SequenceInfo, + ResourceHandlerDict, + UnpagedResourceHandler, ) from .torch_attention import repeat_kv, update_kv_cache @@ -375,11 +374,6 @@ def torch_backend_mha_with_cache_fake( @AttentionRegistry.register("torch") class TorchBackendAttention(AttentionDescriptor): - @classmethod - def is_paged(cls) -> bool: - """Return if the attention op is paged or not.""" - return False - @classmethod def get_attention_layout(cls) -> AttentionLayout: """Get the attention layout expected by the source op and the cached attention op.""" @@ -404,8 +398,8 @@ class TorchBackendAttention(AttentionDescriptor): @classmethod def get_cache_initializers( - cls, source_attn_node: Node, cache_config: CacheConfig - ) -> CacheInitializerDict: + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: # source op is [bsnd] layout already k_fake: FakeTensor = source_attn_node.args[1].meta["val"] v_fake: FakeTensor = source_attn_node.args[2].meta["val"] @@ -413,33 +407,18 @@ class TorchBackendAttention(AttentionDescriptor): k_head_dim = k_fake.shape[3] v_head_dim = v_fake.shape[3] - def _get_k_cache(si: SequenceInfo): - assert not si.is_paged, "Paged cache not supported for torch backend" - return torch.empty( - si.num_pages, - si.page_size, + return { + "k_cache": UnpagedResourceHandler( num_kv_heads, k_head_dim, - device=si.device, - dtype=cache_config.dtype or k_fake.dtype, - ) - - def _get_v_cache(si: SequenceInfo): - assert not si.is_paged, "Paged cache not supported for torch backend" - return torch.empty( - si.num_pages, - si.page_size, + dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype), + ), + "v_cache": UnpagedResourceHandler( num_kv_heads, v_head_dim, - device=si.device, - dtype=cache_config.dtype or v_fake.dtype, - ) - - return {"k_cache": _get_k_cache, "v_cache": _get_v_cache} - - @classmethod - def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict: - return {} + dtype=cls.resolve_cache_dtype(cache_config.dtype, v_fake.dtype), + ), + } @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py index c1c0255686..3fec7800d4 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py @@ -172,7 +172,7 @@ def torch_fake_quant_fp8_linear( - input_zp / weight_zp ignored """ if weight_quantized.dtype != torch.float8_e4m3fn: - raise TypeError("FP8 path requires weight_quantized.dtype == float8_e4m3fn") + raise TypeError("FP8 path requires weight_quantized.dtype == torch.float8_e4m3fn") s_in = _expect_single_scale(input_scale, "input_scale") s_w = 
_expect_single_scale(weight_scale, "weight_scale") diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py index 8a9daf7523..7aa8239c08 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py @@ -9,18 +9,17 @@ from torch._ops import OpOverloadPacket from torch._subclasses import FakeTensor from torch.fx import Node +from ....llmapi.llm_args import KvCacheConfig from ..utils.logger import ad_logger from ..utils.node_utils import extract_op_args from .attention_interface import ( AttentionDescriptor, AttentionLayout, AttentionRegistry, - BufferInitializerDict, - CacheConfig, - CacheInitializerDict, Constant, MHACallable, - SequenceInfo, + ResourceHandlerDict, + UnpagedResourceHandler, ) from .triton_kernels.attention_with_kv_cache import ( attention_kv_stage2, @@ -312,11 +311,6 @@ def flattened_mha_fake( @AttentionRegistry.register("triton") class TritonAttention(AttentionDescriptor): - @classmethod - def is_paged(cls) -> bool: - """Return if the attention op is paged or not.""" - return False - @classmethod def get_attention_layout(cls) -> AttentionLayout: """Get the attention layout expected by the source op and the cached attention op.""" @@ -341,8 +335,8 @@ class TritonAttention(AttentionDescriptor): @classmethod def get_cache_initializers( - cls, source_attn_node: Node, cache_config: CacheConfig - ) -> CacheInitializerDict: + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: # source op is [bsnd] layout already k_fake: FakeTensor = source_attn_node.args[1].meta["val"] v_fake: FakeTensor = source_attn_node.args[2].meta["val"] @@ -350,33 +344,18 @@ class TritonAttention(AttentionDescriptor): k_head_dim = k_fake.shape[3] v_head_dim = v_fake.shape[3] - def _get_k_cache(si: SequenceInfo): - assert not si.is_paged, "Paged cache not supported for triton" - return torch.empty( - si.num_pages, - si.page_size, + return { + "k_cache": UnpagedResourceHandler( num_kv_heads, k_head_dim, - device=si.device, - dtype=cache_config.dtype or k_fake.dtype, - ) - - def _get_v_cache(si: SequenceInfo): - assert not si.is_paged, "Paged cache not supported for triton" - return torch.empty( - si.num_pages, - si.page_size, + dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype), + ), + "v_cache": UnpagedResourceHandler( num_kv_heads, v_head_dim, - device=si.device, - dtype=cache_config.dtype or v_fake.dtype, - ) - - return {"k_cache": _get_k_cache, "v_cache": _get_v_cache} - - @classmethod - def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict: - return {} + dtype=cls.resolve_cache_dtype(cache_config.dtype, v_fake.dtype), + ), + } @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index f892fa7bf7..eb22e116b7 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -9,7 +9,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict from ...llmapi.llm_args import ( BuildConfig, EagleDecodingConfig, - KvCacheConfig, SamplerType, TorchLlmArgs, _ParallelConfig, @@ -45,7 +44,6 @@ def _check_for_default_value_only( _TRANSFORMS_SHORTCUT_LOOKUP = { "attn_backend": ("insert_cached_attention.backend", "transformers_replace_cached_attn.backend"), - "free_mem_ratio": 
("resize_kv_cache.free_mem_ratio",), "compile_backend": ("compile_model.backend",), "cuda_graph_batch_sizes": ("compile_model.cuda_graph_batch_sizes",), } @@ -191,25 +189,11 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings): device: str = Field(default="cuda", description="The device to use for the model.", frozen=True) - # TODO: see if we can just remove this field and use kv_cache_config.dtype instead? - kv_cache_dtype: str = Field( - default="auto", - description="Data type for KV cache. This is a temporary field until kv_cache_dtype is " - "supported in AutoDeploy.", - ) - sampler_type: Union[str, SamplerType] = Field( default=SamplerType.TorchSampler, description="The type of sampler to use. Options are TRTLLMSampler or TorchSampler. Defaults to TorchSampler.", ) - # NOTE: we do not support copy_on_partial_reuse in AutoDeploy yet - # see https://github.com/NVIDIA/TensorRT-LLM/issues/7142 - kv_cache_config: KvCacheConfig = Field( - default_factory=lambda **kwargs: KvCacheConfig(copy_on_partial_reuse=False, **kwargs), - description="KV cache config.", - ) - max_beam_width: int = Field( default=1, description="The maximum beam width. >1 is not supported by AutoDeploy.", @@ -240,12 +224,6 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings): default="flashinfer", description=_shortcut_description("Attention backend to use.", "attn_backend"), ) - free_mem_ratio: float = Field( - default=0.0, - description=_shortcut_description( - "The fraction of available memory to allocate for cache.", "free_mem_ratio" - ), - ) compile_backend: str = Field( default="torch-compile", description=_shortcut_description( @@ -263,16 +241,8 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings): ) ### SEQUENCE INTERFACE CONFIG ################################################################## - max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.") max_seq_len: int = Field(default=512, ge=1, description="The maximum sequence length.") max_batch_size: int = Field(default=8, ge=1, description="The maximum batch size.") - attn_page_size: int = Field( - default=64, - ge=1, - description="Page size for attention (tokens_per_block). For triton and torch " - "backends, this should equal max_seq_len. Temporary field until tokens_per_block gets " - "properly passed through.", - ) def model_dump(self, *args, **kwargs): """Convert the arguments to a dictionary that can be used as kwargs for the LLM API.""" @@ -292,23 +262,6 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings): return kwargs ### VALIDATION ################################################################################# - @model_validator(mode="after") - # TODO: discuss what to do with this once we fully transition to the new inference optimizer - def update_attn_page_size(self): - # NOTE force attn_page_size to equal max_seq_len for triton backend - if self.transforms.get("insert_cached_attention", {}).get("backend") in [ - "triton", - "torch", - ]: - self.attn_page_size = self.max_seq_len - # NOTE: (hg) For transformers mode. This is ugly. 
- if self.transforms.get("transformers_replace_cached_attn", {}).get("backend") in [ - "triton", - "torch", - ]: - self.attn_page_size = self.max_seq_len - return self - @field_validator("model_factory", mode="after") @classmethod def model_factory_exists(cls, value: str) -> str: @@ -358,16 +311,6 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings): self.update_transforms_with_shortcuts() return self - @field_validator("kv_cache_config", mode="after") - @classmethod - def validate_kv_cache_config(cls, kv_cache_config: KvCacheConfig) -> KvCacheConfig: - if kv_cache_config.copy_on_partial_reuse: - kv_cache_config.copy_on_partial_reuse = False - ad_logger.warning( - "copy_on_partial_reuse is not supported by AutoDeploy. Setting it to False." - ) - return kv_cache_config - ### UTILITY METHODS ############################################################################ def create_factory(self) -> ModelFactory: """Create a model factory from the arguments.""" diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py index 97a34e481b..ecfc11889b 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/factory.py +++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py @@ -27,10 +27,6 @@ from torch._prims_common import DeviceLikeType from torch.export import Dim from torch.fx import GraphModule -from ..custom_ops.attention_interface import CacheConfig -from ..utils.cuda_mem_tracker import get_mem_info_in_mb -from ..utils.logger import ad_logger - DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension @@ -211,13 +207,15 @@ class ModelFactory(ABC): """Returns the sharding config for this model.""" return self._sharding_config - def get_cache_config(self) -> CacheConfig: - """Return the cache configuration for the model. + def get_cache_config_updates(self) -> Dict[str, Any]: + """Return updates for the KVCacheConfig for the model. Returns: - The cache configuration for the model. + A dictionary of updates for the KVCacheConfig for the model. + + Check tensorrt_llm/llmapi/llm_args.py for the KVCacheConfig fields. """ - return CacheConfig() + return {} def init_tokenizer(self) -> Optional[Any]: """Initialize the tokenizer for the model. @@ -301,22 +299,12 @@ class ModelFactory(ABC): """ - ad_logger.info("Loading and initializing weights.") - free_mem_pre, _ = get_mem_info_in_mb() - ad_logger.info(f"Free memory before loading weights (MB): {free_mem_pre}") self._to_maybe_random(model, device) - params_size = sum(p.numel() * p.element_size() for p in model.parameters()) - total_size_GB = params_size / (1024**3) - ad_logger.info(f"Estimated parameters memory: {total_size_GB:.2f} GB") if not self.skip_loading_weights: self.prefetch_checkpoint(force=True) self._load_checkpoint(model, device) - ad_logger.info("Loading and initializing weights. Done.") - free_mem_post, _ = get_mem_info_in_mb() - ad_logger.info(f"Free memory after loading weights (MB): {free_mem_post}") - @staticmethod def _to_maybe_random(model: nn.Module, device: DeviceLikeType): """A mix of ``model.to(device)`` and random initialization of parameters. 
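
With `get_cache_config_updates` now returning a plain dict instead of a `CacheConfig` object, the factory only reports field overrides, and the caller (not shown in these hunks) is responsible for folding them into the `KvCacheConfig`. The self-contained sketch below is illustrative only: it mirrors the "apply a dict of updates, reject unknown fields" pattern that `update_kv_cache_config` implements later in this diff. `KvCacheConfigStub` and `apply_updates` are stand-ins, not TensorRT-LLM APIs, and the stub's fields and defaults are assumptions.

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel


class KvCacheConfigStub(BaseModel):
    """Stand-in for the real KvCacheConfig; fields and defaults are illustrative."""

    dtype: str = "auto"
    free_gpu_memory_fraction: Optional[float] = 0.9
    tokens_per_block: int = 32


def apply_updates(cfg: KvCacheConfigStub, updates: Dict[str, Any]) -> None:
    """Apply factory-provided overrides, rejecting unknown field names."""
    for key, value in updates.items():
        if key not in type(cfg).model_fields:
            raise ValueError(f"Invalid KvCacheConfig field: {key}")
        setattr(cfg, key, value)


cfg = KvCacheConfigStub()
apply_updates(cfg, {"dtype": "fp8"})  # e.g. what an FP8-quantized checkpoint would report
assert cfg.dtype == "fp8"
```
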
diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index bf5384af15..fad26fa6e2 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -33,7 +33,6 @@ from transformers.utils import ( WEIGHTS_NAME, ) -from ..custom_ops.attention_interface import CacheConfig from ..utils._config import deep_merge_dicts from ..utils.logger import ad_logger from .factory import ( @@ -261,18 +260,16 @@ class AutoModelForCausalLMFactory(AutoModelFactory): return self._quant_config_reader.get_config() return {} - def get_cache_config(self): - """Return kv cache dtype configuration.""" + def get_cache_config_updates(self): + """Return kv cache dtype updates.""" if not self._quant_config_reader: - return CacheConfig(dtype=None) + return {} - kv_cache_dtype = self._quant_config_reader.get_config().get("kv_cache_dtype") - torch_dtype = torch.float8_e4m3fn if kv_cache_dtype == "float8_e4m3fn" else None - assert torch_dtype in (torch.float8_e4m3fn, None), ( - f"Unsupported dtype: {torch_dtype}. Only torch.float8_e4m3fn is supported." + kv_cache_dtype = self._quant_config_reader.get_config().get("kv_cache_dtype", "auto") + assert kv_cache_dtype in ("fp8", "auto"), ( + f"Unsupported dtype: {kv_cache_dtype}. Only fp8 and auto are supported." ) - - return CacheConfig(dtype=torch_dtype) + return {"dtype": kv_cache_dtype} def init_tokenizer(self) -> Optional[Any]: """Initialize the tokenizer—either a custom name or the model's default.""" diff --git a/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py b/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py index c2b15198b6..9a16e0b972 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py +++ b/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py @@ -106,7 +106,7 @@ class ModelOPTQuantConfigReader(QuantConfigReader): if kv_algo: if kv_algo != "FP8": raise ValueError(f"KV cache quantization format {kv_algo} not supported.") - quant_config["kv_cache_dtype"] = "float8_e4m3fn" + quant_config["kv_cache_dtype"] = "fp8" self._quant_config = quant_config diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 5ee79f9572..a7e1d8525f 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -41,6 +41,7 @@ from tensorrt_llm._utils import nvtx_range from tensorrt_llm.llmapi.llm_args import ( ContextChunkingPolicy, EagleDecodingConfig, + KvCacheConfig, LoadFormat, SamplerType, TorchLlmArgs, @@ -48,9 +49,9 @@ from tensorrt_llm.llmapi.llm_args import ( from tensorrt_llm.llmapi.tokenizer import TokenizerBase from ...._utils import get_free_port, mpi_rank, mpi_world_size -from ....bindings.internal.batch_manager import CacheType from ....mapping import Mapping from ...distributed import Distributed +from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine from ...pyexecutor.py_executor import PyExecutor from ...pyexecutor.resource_manager import ( @@ -67,7 +68,6 @@ from ...pyexecutor.scheduler import ( ScheduledRequests, SimpleScheduler, ) -from ..custom_ops.attention_interface import SequenceInfo from ..distributed.common import initialize_or_skip from ..llm_args import LlmArgs from ..transform.optimizer import InferenceOptimizer @@ -82,44 +82,6 @@ class ReportingInfo: enable_iter_req_stats: bool = False -class 
_CacheManagerWithFakePool(KVCacheManager): - """We use the default KVCacheManager but with a fake pool by setting head_dim=0. - - The actual cache pools are managed by auto_deploy layerwise cache pools. - """ - - def __init__( - self, - kv_cache_config, - num_blocks: int, - tokens_per_block: int, - max_seq_len: int, - max_batch_size: int, - ): - self.num_blocks = num_blocks - super().__init__( - kv_cache_config=kv_cache_config, - kv_cache_type=CacheType.SELF, - num_layers=1, - num_kv_heads=1, - head_dim=0, - tokens_per_block=tokens_per_block, - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, - mapping=Mapping(), - ) - - def calculate_max_num_blocks( - self, kv_cache_config, head_dim, tokens_per_block, mapping, dtype, kv_factor - ) -> Tuple[int, int]: - """Calculate the maximum number of blocks needed for the cache.""" - # TODO: this is VERY hacky... Ideally, we want to compute the number of blocks - # just like in the original implementation. However, let's wait for the layer-wise attention - # implementation before over-optimizing the function here - ad_logger.info("Using fake cache manager with head_dim=0 and num pages:", self.num_blocks) - return self.num_blocks, 0 - - class ADHiddenStateManager(Eagle3ResourceManager): def __init__( self, @@ -275,6 +237,7 @@ def construct_draft_llm_args( def create_draft_kv_cache_manager_maybe( draft_model_engine: Optional[PyTorchModelEngine], ad_config: LlmArgs, + kv_cache_config_tuned: KvCacheConfig, dist_mapping: Mapping, ) -> Optional[KVCacheManager]: if draft_model_engine is None or not draft_model_engine.model.model_config.is_generation: @@ -287,8 +250,8 @@ def create_draft_kv_cache_manager_maybe( model_engine=draft_model_engine, kv_cache_manager_cls=kv_cache_manager_cls, mapping=dist_mapping, - kv_cache_config=ad_config.kv_cache_config, - tokens_per_block=ad_config.attn_page_size, + kv_cache_config=kv_cache_config_tuned, + tokens_per_block=kv_cache_config_tuned.tokens_per_block, max_seq_len=ad_config.max_seq_len, max_batch_size=ad_config.max_batch_size, spec_config=ad_config.speculative_config, @@ -321,19 +284,33 @@ def _generate_dummy_request( ResourceManagerType.SPEC_RESOURCE_MANAGER ) - # check if we have a free slot available and free page available - if not slot_manager.slot_manager.free_slots or kv_cache_manager.get_num_free_blocks() == 0: + # check if it's a hybrid kv-cache manager + is_hybrid_cache = isinstance(kv_cache_manager, MambaHybridCacheManager) + + # check if we have a free page and free state available + if not kv_cache_manager.get_num_free_blocks(): + return None + if is_hybrid_cache and not kv_cache_manager.mamba_cache_free_blocks: return None # generate a dummy request dummy_request = kv_cache_manager.add_dummy_requests([request_id], **request_kwargs)[0] dummy_request.is_cuda_graph_dummy = True + # generate a dummy scheduled requests object + dummy_scheduled_requests = ScheduledRequests() + dummy_scheduled_requests.generation_requests.append(dummy_request) + + # if it's a hybrid kv-cache manager, we need to manually call prepare_resources again (not done + # in add_dummy_requests) + if is_hybrid_cache: + kv_cache_manager.prepare_resources(dummy_scheduled_requests) + # add to spec resource manager if spec_res_mgr: spec_res_mgr.add_dummy_requests([request_id]) - # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9883 clean up this hack + # NOTE: hack to avoid blocking a slot for the dummy request dummy_request.seq_slot = slot_manager.get_max_resource_count() dummy_request.py_seq_slot = dummy_request.seq_slot @@ 
-448,11 +425,6 @@ class ADEngine(ModelEngine): ): """Build the ADEngine using the LlmArgs that gets passed through from the LLM.""" - max_batch_size = ad_config.max_batch_size - max_seq_len = ad_config.max_seq_len - attn_page_size = ad_config.attn_page_size - max_num_tokens = ad_config.max_num_tokens - # update device to contain the current default device if it's in cuda device = torch.device(ad_config.device) if device.type == "cuda" and device.index is None: @@ -461,14 +433,17 @@ class ADEngine(ModelEngine): factory = ad_config.create_factory() - # initialize seq info object - seq_info = SequenceInfo( - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, - page_size=attn_page_size, - max_num_tokens=max_num_tokens, + # Initialize CachedSequenceInterface - it will create SequenceInfo internally + # using tokens_per_block from kv_cache_config + cache_seq_interface = CachedSequenceInterface( + max_seq_len=ad_config.max_seq_len, + max_batch_size=ad_config.max_batch_size, + device=device, + kv_cache_config=ad_config.kv_cache_config, + max_num_tokens=ad_config.max_num_tokens, vocab_size_padded=factory.vocab_size_padded, ) + reporting_info = ReportingInfo( print_log=False, enable_iter_perf_stats=ad_config.enable_iter_perf_stats, @@ -483,8 +458,7 @@ class ADEngine(ModelEngine): # construct engine return cls( build_and_optimize, - seq_info, - device, + cache_seq_interface, ad_config=ad_config, mapping=mapping, dist=dist, @@ -495,14 +469,21 @@ class ADEngine(ModelEngine): def __init__( self, get_inference_model: GetInferenceModel, - seq_info: SequenceInfo, - device: DeviceLikeType, + cache_seq_interface: CachedSequenceInterface, ad_config: Optional[LlmArgs] = None, mapping: Optional[Mapping] = None, dist: Optional[Distributed] = None, reporting_info: ReportingInfo = ReportingInfo(), ) -> None: - """Initialize the engine with model and sequence information.""" + """Initialize the engine with model and CachedSequenceInterface. + + Args: + get_inference_model: Callable that builds the inference model. + cache_seq_interface: The CachedSequenceInterface containing sequence and cache config. + ad_config: Optional LLM configuration. + mapping: Optional distributed mapping configuration. + reporting_info: Reporting configuration for logging. + """ # NOTE (lucaslie): create a fake Namespace to satisfy PyExecutor requirements... # This is not correctly declared in the base ModelEngine class though... 
self.llm_args = SimpleNamespace() @@ -514,8 +495,8 @@ class ADEngine(ModelEngine): self.llm_args.batch_wait_timeout_ms = 0 self.llm_args.batch_wait_timeout_iters = 0 self.llm_args.batch_wait_max_tokens_ratio = 0.0 - self.llm_args.max_num_tokens = seq_info.max_num_tokens - self.llm_args.max_seq_len = seq_info.max_seq_len + self.llm_args.max_num_tokens = cache_seq_interface.info.max_num_tokens + self.llm_args.max_seq_len = cache_seq_interface.info.max_seq_len self.iter_counter = 0 self.iter_states = {} @@ -546,13 +527,10 @@ class ADEngine(ModelEngine): ) # For compatibility with PyTorchModelEngine utilities - self.batch_size = seq_info.max_batch_size + self.batch_size = cache_seq_interface.info.max_batch_size - # construct cache sequence interface - self.cache_seq_interface = CachedSequenceInterface( - sequence_info=seq_info, - device=device, - ) + # Store the cache sequence interface + self.cache_seq_interface = cache_seq_interface # build model self.model = get_inference_model(self.cache_seq_interface) @@ -628,12 +606,18 @@ class ADEngine(ModelEngine): # gather indices for logits logits_gather_indices: List[int] = [] - page_size = self.cache_seq_interface.info.page_size + page_size = kv_cache_manager.tokens_per_block dummy_token = -1 num_ctx_requests = len(context_requests) num_ctx_tokens = 0 num_generation_tokens = 0 + # Helper to get slot index - use mamba_cache_index if available (MambaHybridCacheManager) + def _get_slot_idx(request) -> int: + if hasattr(kv_cache_manager, "mamba_cache_index"): + return kv_cache_manager.mamba_cache_index[request.py_request_id] + return request.seq_slot + # look at context requests first for request in context_requests: # store input ids and pos of first token in sequence @@ -669,8 +653,8 @@ class ADEngine(ModelEngine): position_ids.append(list(range(input_pos[-1], seq_len_with_cache[-1]))) - # store seq slot idx - slot_idx.append(request.seq_slot) + # store seq slot idx (use mamba_cache_index if available) + slot_idx.append(_get_slot_idx(request)) use_initial_states.append(input_pos[-1] > 0) # store extra arguments @@ -749,7 +733,8 @@ class ADEngine(ModelEngine): num_generation_tokens += 1 + get_draft_token_length(request) request.py_batch_idx = request.seq_slot - slot_idx.append(request.seq_slot) + # store seq slot idx (use mamba_cache_index if available) + slot_idx.append(_get_slot_idx(request)) use_initial_states.append(input_pos[-1] > 0) seq_len.append(len(input_ids[-1])) @@ -941,11 +926,11 @@ def create_draft_model_engine_maybe( draft_spec_config = copy.copy(spec_config) - kv_cache_config = ad_config.kv_cache_config + kv_cache_config_tuned = target_engine.cache_seq_interface.kv_cache_config_tuned attn_runtime_features = AttentionRuntimeFeatures( chunked_prefill=ad_config.enable_chunked_prefill, - cache_reuse=kv_cache_config.enable_block_reuse, + cache_reuse=kv_cache_config_tuned.enable_block_reuse, has_speculative_draft_tokens=has_spec_drafter, chunk_size=target_engine.llm_args.max_num_tokens, ) @@ -1108,40 +1093,15 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer else None ) - # check kvcache config for partial block reuse - # TODO: copy_on_partial_reuse is not supported yet, see - # https://github.com/NVIDIA/TensorRT-LLM/issues/7142 for more details. 
- enable_block_reuse = ad_config.kv_cache_config.enable_block_reuse - enable_partial_reuse = ad_config.kv_cache_config.enable_partial_reuse - copy_on_partial_reuse = ad_config.kv_cache_config.copy_on_partial_reuse - if enable_block_reuse and enable_partial_reuse and copy_on_partial_reuse: - raise RuntimeError( - f"partial block reuse with {copy_on_partial_reuse=} set to True is NOT supported" - " in AutoDeploy. Please set it to False via the kv_cache_config.copy_on_partial_reuse " - "field in tensorrt_llm._torch.auto_deploy.llm_args.LlmArgs." - ) - - # TODO: detect whether SSM layer is present in the model and raise an error or disable block - # reuse with a warning --> see https://github.com/NVIDIA/TensorRT-LLM/issues/7142. For now, we - # just emit a general warning. - if enable_block_reuse: - ad_logger.warning( - f"{enable_block_reuse=} is enabled. Note that this is not supported for SSM layers and" - " may lead to incorrect results if the model contains SSM layers." - ) - # resource managers - kv_cache_manager = _CacheManagerWithFakePool( - ad_config.kv_cache_config, - num_blocks=engine.cache_seq_interface.info.num_pages, - tokens_per_block=ad_config.attn_page_size, - max_seq_len=ad_config.max_seq_len, - max_batch_size=ad_config.max_batch_size, - ) + # KVCacheManager is now created and managed by CachedSequenceInterface during the + # initialize_cache/resize_kv_cache transform pipeline. Get it from the interface. + kv_cache_manager = engine.cache_seq_interface.kv_cache_manager + kv_cache_config_tuned = engine.cache_seq_interface.kv_cache_config_tuned seq_slot_manager = SeqSlotManager(max_num_sequences=max_num_sequences) draft_kv_cache_manager = create_draft_kv_cache_manager_maybe( - draft_model_engine, ad_config, dist_mapping + draft_model_engine, ad_config, kv_cache_config_tuned, dist_mapping ) resource_manager = ResourceManager( @@ -1160,7 +1120,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer # Chunked prefill if ad_config.enable_chunked_prefill: - chunk_unit_size = ad_config.attn_page_size + chunk_unit_size = kv_cache_config_tuned.tokens_per_block chunking_policy = ContextChunkingPolicy.FIRST_COME_FIRST_SERVED ctx_chunk_config: Tuple[StrEnum, int] = (chunking_policy, chunk_unit_size) else: diff --git a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py index eb14888b91..3cedb2300d 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py @@ -71,14 +71,15 @@ class DemoEngine(ADEngine): currently available token slots in the assigned pages and assign a new, previously unassigned page if needed. 
""" - si = self.cache_seq_interface.info - page_assignments = si.page_assignments + page_assignments = self.cache_seq_interface.info.page_assignments + num_pages = self.cache_seq_interface.kv_cache_manager.blocks_in_primary_pool + tokens_per_block = self.cache_seq_interface.kv_cache_manager.tokens_per_block - free_pages = set(range(si.num_pages)) - {i for pages in page_assignments for i in pages} + free_pages = set(range(num_pages)) - {i for pages in page_assignments for i in pages} updated_assignments = [] for t_l, pages in zip(total_lens, page_assignments): - extra_tokens = t_l - len(pages) * si.page_size - num_extra_pages = (extra_tokens // si.page_size) + (extra_tokens > 0) + extra_tokens = t_l - len(pages) * tokens_per_block + num_extra_pages = (extra_tokens // tokens_per_block) + (extra_tokens > 0) updated_assignments.append(pages + [free_pages.pop() for _ in range(num_extra_pages)]) return updated_assignments diff --git a/tensorrt_llm/_torch/auto_deploy/shim/interface.py b/tensorrt_llm/_torch/auto_deploy/shim/interface.py index 9abfa911ba..0ff7c80fac 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/interface.py @@ -1,95 +1,473 @@ -from typing import Callable, Dict, List, Optional, Tuple, final +import copy +import functools +import math +from typing import Callable, Dict, Optional, Tuple, Union, final import torch import torch.nn as nn from torch._prims_common import DeviceLikeType -from ..custom_ops.attention_interface import GetCacheCallable, SequenceInfo +import tensorrt_llm.bindings +from tensorrt_llm.llmapi.llm_args import KvCacheConfig +from tensorrt_llm.mapping import Mapping + +from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager +from ...pyexecutor.resource_manager import KVCacheManager +from ..custom_ops.attention_interface import ( + PagedResourceHandler, + ResourceHandler, + ResourceHandlerDict, + SequenceInfo, + StateResourceHandler, +) +from ..distributed.common import all_gather_object, get_world_size +from ..distributed.common import is_initialized as is_distributed_initialized +from ..utils.cuda_mem_tracker import bytes_to, get_mem_info +from ..utils.logger import ad_logger + +CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType +DataType = tensorrt_llm.bindings.DataType + + +def with_pre_callback(method, callback): + """Wrap method to call callback before the original method.""" + + @functools.wraps(method) + def wrapper(*args, **kwargs): + callback() + return method(*args, **kwargs) + + return wrapper @final class CachedSequenceInterface: - """An interface responsible for maintaining information about sequences and their caches.""" + """An interface responsible for maintaining information about sequences and their caches. + + This class is the single source of truth for sequence and cache configuration. It creates + SequenceInfo internally, ensuring that tokens_per_block and other fields from KvCacheConfig + are always consistent. + """ def __init__( - self, sequence_info: SequenceInfo, device: Optional[DeviceLikeType] = None + self, + max_seq_len: int, + max_batch_size: int, + device: Optional[DeviceLikeType] = None, + kv_cache_config: Optional[KvCacheConfig] = None, + max_num_tokens: Optional[int] = None, + vocab_size_padded: Optional[int] = None, ) -> None: + """Initialize the CachedSequenceInterface. + + Args: + max_seq_len: Maximum sequence length including input and generated tokens. + max_batch_size: Maximum number of sequences (requests) that can be processed. 
+ device: Target device for tensors. Defaults to "cuda". + kv_cache_config: KV cache configuration. If None, uses default KvCacheConfig. + max_num_tokens: Maximum total tokens across all sequences. If None, computed from + max_seq_len and max_batch_size. + vocab_size_padded: Padded vocabulary size of the model. + """ # TODO (lucaslie): this is somewhat circular/confusing. Here `device` denotes the desired # device and not the actual device unlike, e.g., in SequenceInfo. We rely on the attribute # here to read the desired device across the inference optimizer pipeline. We should ideally # think about a better way to handle this, # see https://github.com/NVIDIA/TensorRT-LLM/issues/8371 self.device = device or "cuda" - self.info = sequence_info - self._cache_initializers: Dict[str, GetCacheCallable] = {} + + # Initialize kv_cache_config first since SequenceInfo needs tokens_per_block from it + self._kv_cache_config_original: KvCacheConfig = kv_cache_config or KvCacheConfig() + self._kv_cache_config_tuned: Optional[KvCacheConfig] = None + + # Create SequenceInfo internally, using tokens_per_block from kv_cache_config + self.info = SequenceInfo( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + tokens_per_block=self._kv_cache_config_original.tokens_per_block, + max_num_tokens=max_num_tokens, + vocab_size_padded=vocab_size_padded, + ) + + self._resource_lookup: ResourceHandlerDict = {} self._caches: Dict[str, torch.Tensor] = {} + # KVCacheManager (or MambaHybridCacheManager) for managed resources + self._kv_cache_manager: Optional[Union[KVCacheManager, MambaHybridCacheManager]] = None + # Ordered dicts tracking resource handlers by type + self._paged_cache_order: ResourceHandlerDict = {} # Paged resources (kv caches) + self._state_resource_order: ResourceHandlerDict = {} # State resources (ssm states) @property def args(self) -> Tuple[torch.Tensor, ...]: """Return all the graph arguments owned by this interface.""" - return (*self.info.args, *self._caches.values()) + return tuple(self.named_args.values()) @property def named_args(self) -> Dict[str, torch.Tensor]: """Return all the named arguments owned by this interface.""" return {**self.info.named_args, **self._caches} - @property - def all_future_arg_names(self) -> List[str]: - """Return all the argument names owned by this interface including uninitialized caches.""" - return list(self.info.named_args.keys()) + list(self._cache_initializers.keys()) - def to(self, *args, **kwargs) -> None: self.info.to(*args, **kwargs) - if self._caches: - for cache in self._caches.values(): + # Only move locally-allocated caches (paged/state caches are managed by cache managers) + for name, cache in self._caches.items(): + if name not in self._paged_cache_order and name not in self._state_resource_order: cache.to(*args, **kwargs) - def add_cache(self, name: str, get_cache: GetCacheCallable) -> None: - """Add a cache initializer to the cache interface.""" - self._cache_initializers[name] = get_cache + def update_kv_cache_config(self, **kwargs) -> None: + """Update the KVCacheConfig with the given kwargs.""" + for k, v in kwargs.items(): + if k in type(self._kv_cache_config_original).model_fields: + setattr(self._kv_cache_config_original, k, v) + else: + raise ValueError(f"Invalid KVCacheConfig field: {k}") - def initialize_caches(self) -> int: - """Initialize caches using the cache initializers.""" - assert not self._caches, "Caches already initialized." 
- self.info.to(self.device) - self._caches = { - name: get_cache(self.info) for name, get_cache in self._cache_initializers.items() + def add_resource(self, name: str, resource_handler: ResourceHandler) -> None: + """Add a resource handler to the cache interface.""" + self._resource_lookup[name] = resource_handler + + def _create_kv_cache_manager(self, max_tokens: Optional[int] = None) -> int: + """Create KVCacheManager or MambaHybridCacheManager with multi-layer byte-level params. + + This uses a multi-layer approach with byte-level abstraction: + - Paged resources: Each resource gets its own layer in KVCacheManager with + num_kv_heads=bytes_per_token for that resource, head_dim=1. + - State resources: Each resource gets its own layer in MambaCacheManager with + head_dim=bytes_per_slot for that resource. + + Each layer's cache is contiguous, avoiding byte-offset slicing within layers. + + When state resources exist, MambaHybridCacheManager is used to manage both. + + Important NOTE on contiguity of managed resources: + - We only guarantee contiguity for an individual page or an individual state slot. + - Outside of these individual pages/slots, resources are NOT guaranteed to be contiguous. + + Args: + max_tokens: Maximum number of tokens to allocate. If provided, it will use the min value + between this value and max_tokens in kv_cache_config. + + Returns: + The final number of tokens that can be cached in the KVCacheManager. + NOTE: this number may differ from the provided ``max_tokens`` arg for two reasons: + 1. the final number of tokens is synced (min) across ranks + 2. rounding for getting a multiple of tokens_per_block + """ + # Build per-layer num_kv_heads list for paged resources + # Each paged resource becomes one "layer" with num_kv_heads = bytes_per_token + num_kv_heads_per_layer = [ + math.prod(h.token_shape) * h.dtype.itemsize for h in self._paged_cache_order.values() + ] + + # Calculate total bytes per slot for state resources (modeled as single layer) + cumulative_bytes_per_state = [0] + for name, handler in self._state_resource_order.items(): + byte_size = math.prod(handler.state_shape) * handler.dtype.itemsize + cumulative_bytes_per_state.append(cumulative_bytes_per_state[-1] + byte_size) + + # Make a deep copy of the kv_cache_config to avoid modifying the original object + kv_cache_config = copy.deepcopy(self._kv_cache_config_original) + + # Disable copy_on_partial_reuse + # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/10966 + if kv_cache_config.copy_on_partial_reuse: + kv_cache_config.copy_on_partial_reuse = False + ad_logger.info("Disabling copy_on_partial_reuse for AutoDeploy backend.") + + # Update kv_cache_config based on max_tokens if provided + if max_tokens is not None: + # sync max_tokens across ranks + if is_distributed_initialized(): + max_tokens_gathered = [None] * get_world_size() + all_gather_object(max_tokens_gathered, max_tokens) + max_tokens = min(max_tokens_gathered) + kv_cache_config.free_gpu_memory_fraction = None + kv_cache_config.max_tokens = min(kv_cache_config.max_tokens or max_tokens, max_tokens) + + # Check if we should disable block reuse + if kv_cache_config.enable_block_reuse and not self.is_paged(): + kv_cache_config.enable_block_reuse = False + ad_logger.info(f"Setting {kv_cache_config.enable_block_reuse=} for non-paged models.") + + # Make sure to set free_gpu_memory_fraction to None if set to 0.0 + # NOTE: KVCacheConfig validator enforces that free_gpu_memory_fraction must be between 0.0 + # and 1.0 but we allow 0.0 to be set to 
disable resizing (corresponding to None in the + # manager). + if kv_cache_config.free_gpu_memory_fraction == 0.0: + kv_cache_config.free_gpu_memory_fraction = None + + # Common KV cache parameters + kv_cache_kwargs = { + "kv_cache_config": kv_cache_config, + "kv_cache_type": CacheTypeCpp.SELFKONLY, # kv_factor=1, treat K, V separately + "num_layers": len(self._paged_cache_order), # correct num layers + "num_kv_heads": num_kv_heads_per_layer, # per-layer bytes_per_token + "head_dim": 1, # all bytes in num_kv_heads + "tokens_per_block": kv_cache_config.tokens_per_block, + "max_seq_len": self.info.max_seq_len, + "max_batch_size": self.info.max_batch_size, + "mapping": Mapping(), + # NOTE (lucaslie): this is the only 1-byte dtype currently supported by the + # KVCacheManager. Ideally, we would use the typical uint8 dtype for byte-level + # abstraction, but this is not supported. + "dtype": DataType.FP8, # 1-byte dtype for byte-level abstraction + "layer_mask": None, + # NOTE (lucaslie): we can always run with False here since when we are estimating, we + # are explicitly setting the max_tokens in which case it's okay to use False here since + # we don't rely on free_gpu_memory_fraction inside the KVCacheManager. This is similar + # to _torch.pyexecutor._util.KVCacheCreator, which explicitly estimates the max_tokens + # outside of the KVCacheManager. + "is_estimating_kv_cache": False, } + + # update args if we are just doing a dummy cache manager + if not len(self._paged_cache_order): + kv_cache_kwargs.update( + { + "num_layers": 1, + "num_kv_heads": 1, + "head_dim": 1, + } + ) + + if self._state_resource_order: + # NOTE: +1 for cuda graph padding + kv_cache_kwargs["max_batch_size"] = self.info.max_num_state_slots + + self._kv_cache_manager = MambaHybridCacheManager( + # Mamba params for single-layer byte buffer + mamba_d_state=1, + mamba_d_conv=1, # conv_states will have shape [..., 0] (empty) + mamba_num_heads=1, + mamba_n_groups=1, + mamba_head_dim=cumulative_bytes_per_state[-1], # Total bytes per slot + mamba_num_layers=1, # Single layer + mamba_layer_mask=None, # Single enabled layer + mamba_cache_dtype=torch.uint8, # Byte-level + mamba_ssm_cache_dtype=torch.uint8, # Byte-level + # KV cache params + **kv_cache_kwargs, + ) + else: + # No state resources - use pure KVCacheManager + self._kv_cache_manager = KVCacheManager(**kv_cache_kwargs) + + # store the tuned kv_cache_config + self._kv_cache_config_tuned = kv_cache_config + + # Ensure cache_loc capacity is sufficient for the new KVCacheManager + blocks_in_primary_pool = self._kv_cache_manager.blocks_in_primary_pool + tokens_per_block = self._kv_cache_manager.tokens_per_block + self.info.estimate_cache_loc_capacity(blocks_in_primary_pool) + + # Create paged resource views from per-layer buffers + for layer_idx, (name, handler) in enumerate(self._paged_cache_order.items()): + view = self._kv_cache_manager.get_buffers(layer_idx, kv_layout="NHD") + view = view.view(blocks_in_primary_pool, tokens_per_block, -1).view(handler.dtype) + view = view.view(blocks_in_primary_pool, tokens_per_block, *handler.token_shape) + + # Sanity check on contiguity of individual pages + view_one_page = view[0] + assert view_one_page.is_contiguous(), f"Per-page cache for {name} is not contiguous" + + self._caches[name] = view + + for layer_idx, (name, handler) in enumerate(self._state_resource_order.items()): + num_states = len(self._kv_cache_manager.state_indices) + # Get the single-layer ssm_states buffer + # ssm_states shape: [1, num_states, 1, 
total_bytes_per_slot, 1]
+            ssm_buffer = self._kv_cache_manager.get_ssm_states(0)
+            # Flatten to [max_batch, total_bytes_per_slot_for_all_layers]
+            ssm_buffer = ssm_buffer.view(num_states, -1)
+
+            offset_start = cumulative_bytes_per_state[layer_idx]
+            offset_end = cumulative_bytes_per_state[layer_idx + 1]
+
+            # Slice at byte offset, reinterpret dtype, reshape
+            view = ssm_buffer[:, offset_start:offset_end]
+            view = view.view(handler.dtype)
+            view = view.view(num_states, *handler.state_shape)
+
+            # Sanity check on contiguity of individual state slots
+            assert view[0].is_contiguous(), f"Per-slot state for {name} cache is not contiguous"
+
+            self._caches[name] = view
+
+        # Patch shutdown to clear cache views before pool release
+        self._kv_cache_manager.shutdown = with_pre_callback(
+            self._kv_cache_manager.shutdown,
+            self._clear_cache_views,
+        )
+
+        max_resource_count = self._kv_cache_manager.get_max_resource_count()
+        max_tokens_final = max_resource_count * self._kv_cache_manager.tokens_per_block
+
+        return max_tokens_final
+
+    def initialize_resources(self) -> int:
+        """Initialize resources - paged/state caches via cache managers, others separately.
+
+        Paged resources are managed by KVCacheManager (or the KV portion of MambaHybridCacheManager).
+        State resources are managed by the Mamba portion of MambaHybridCacheManager.
+        Other resources are allocated locally as a fallback.
+
+        Returns:
+            The number of caches initialized.
+        """
+        assert not self._caches and not self._paged_cache_order, "Caches already initialized."
+        self.info.to(self.device)
+
+        # Separate resources by type
+        for name, handler in self._resource_lookup.items():
+            if isinstance(handler, PagedResourceHandler):
+                self._paged_cache_order[name] = handler
+                self._caches[name] = None  # Will be set by _create_kv_cache_manager
+            elif isinstance(handler, StateResourceHandler):
+                self._state_resource_order[name] = handler
+                self._caches[name] = None  # Will be set by _create_kv_cache_manager
+            else:
+                # Unknown handler type - allocate locally (fallback)
+                self._caches[name] = handler.allocate(self.info)
+
+        # Create unified cache manager (handles both paged and state resources)
+        if self.needs_resize() or self._requires_token_estimate():
+            max_tokens_estimate = self.info.estimate_cache_tokens_per_forward()
+        else:
+            # if we don't need a resize, we will just use the original settings in kv_cache_config
+            # instead of passing in an overwrite here.
+            max_tokens_estimate = None
+        self._create_kv_cache_manager(max_tokens=max_tokens_estimate)
+
         return len(self._caches)

-    def current_cache_size_bytes(self) -> int:
-        """Calculate and return the total size of all caches in bytes."""
-        total_size = 0
-        for name, cache in self._caches.items():
-            # this hack is needed since _caches also contains global buffers such as freqs_cis.
-            if "cache" in name:
-                total_size += cache.element_size() * cache.numel()
-        return total_size
+    def is_paged(self) -> bool:
+        """Return True if all resources are paged and part of the KVCacheManager."""
+        return set(self._paged_cache_order.keys()) == set(self._resource_lookup.keys())

-    def current_kv_cache_size_bytes(self) -> int:
-        """Return size in bytes of KV caches only (k_cache_*, v_cache_*).
+    def _requires_token_estimate(self) -> bool:
+        """Check if our kv_cache_config requires a token estimate."""
+        return (
+            self._kv_cache_config_original.free_gpu_memory_fraction in [None, 0.0]
+            and self._kv_cache_config_original.max_tokens is None
+        )

-        Excludes SSM/conv/etc. which do not scale with num_pages.
+    def needs_resize(self) -> bool:
+        """Check if we need a resize or not."""
+        has_paged = bool(self._paged_cache_order)
+        return has_paged and self._kv_cache_config_original.free_gpu_memory_fraction not in [
+            None,
+            0.0,
+        ]
+
+    def resize_kv_cache_manager(self, mem_exclude: int = 0) -> None:
+        """Shut down the existing KVCacheManager and create a new one with optimal capacity.
+
+        Args:
+            mem_exclude: Extra memory to exclude from the calculation of optimal capacity.
+                This is in bytes and typically the memory reserved for the forward pass.
+
+        This implements the two-phase approach: after running a forward pass during estimation
+        to allocate intermediate memory, call this method to recreate the KVCacheManager.
+        The new manager will compute optimal capacity based on current free GPU memory
+        via calculate_max_num_blocks.
         """
-        total_size = 0
-        for name, cache in self._caches.items():
-            if name.startswith("k_cache_") or name.startswith("v_cache_"):
-                total_size += cache.element_size() * cache.numel()
-        return total_size
+        if not self.needs_resize():
+            return

-    def resize_cache(self, new_num_pages: int):
-        """Resize the cache to the new number of pages."""
-        # TODO: We should do some sanity check on the new number of pages.
-        self.info.num_pages = new_num_pages
-        for name, cache in self._caches.items():
-            # We assume cache is a tensor of shape (max_batch_size, page_size, n_heads, head_dim)
-            # TODO: cache resize should ideally be handled via a callback to the AttentionDescriptor
-            # to avoid hard-coding any assumptions about the cache shape or its "pagedness"
-            if "k_cache" in name or "v_cache" in name:
-                current_shape = cache.shape
-                new_shape = (new_num_pages, *current_shape[1:])
-                cache.resize_(new_shape)
+        # get per-token cache size for resizable resources
+        paged_cache_bytes_per_token = self._kv_cache_manager.get_cache_bytes_per_token()
+
+        # get total cache size of state resources that cannot be resized
+        # NOTE: this does NOT include resources handled OUTSIDE of the KVCacheManager or
+        # MambaHybridCacheManager. Those will persist and will be accounted for via free_mem even
+        # after the initial kv_cache_manager is shut down.
+        state_cache_bytes_total = sum(
+            cache.numel() * cache.element_size()
+            for name, cache in self._caches.items()
+            if name in self._state_resource_order
+        )
+
+        # get unmanaged cache size
+        unmanaged_cache_bytes_total = sum(
+            cache.numel() * cache.element_size()
+            for name, cache in self._caches.items()
+            if name not in self._paged_cache_order and name not in self._state_resource_order
+        )
+
+        # Shutdown existing KVCacheManager to free memory
+        self._kv_cache_manager.shutdown()
+
+        # Get current free GPU memory (this already accounts for model weights + non-managed resources)
+        _, free_mem, *_ = get_mem_info(empty_cache=True)
+
+        # Compute available memory for the KVCacheManager
+        # NOTE: free_mem was obtained AFTER shutdown of initial KVCacheManager - hence it accounts
+        # for unmanaged resources but it does NOT account for state resources since those were
+        # freed as part of the shutdown.
+ free_gpu_memory_fraction = self._kv_cache_config_original.free_gpu_memory_fraction + mem_for_paged_optimal = ( + free_mem - state_cache_bytes_total - mem_exclude + ) * free_gpu_memory_fraction + # Check how many tokens we can fit into the paged cache + max_tokens_optimal = int(mem_for_paged_optimal // paged_cache_bytes_per_token) + + # Create new KVCacheManager with final capacity + max_tokens_final = self._create_kv_cache_manager(max_tokens=max_tokens_optimal) + + # Log resulting memory information + mem_info = [ + f"free_mem={bytes_to(free_mem, unit='GB'):.2f}GB", + f"free_gpu_memory_fraction={free_gpu_memory_fraction}", + f"mem_exclude={bytes_to(mem_exclude, unit='GB'):.2f}GB", + f"mem_exclude_for_state={bytes_to(state_cache_bytes_total, unit='GB'):.2f}GB", + f"mem_for_paged_optimal={bytes_to(mem_for_paged_optimal, unit='GB'):.2f}GB", + ] + total_cache_bytes = ( + mem_for_paged_optimal + state_cache_bytes_total + unmanaged_cache_bytes_total + ) + mem_cache_info = [ + f"Max Tokens={max_tokens_final}", + f"Paged={bytes_to(mem_for_paged_optimal, unit='GB'):.2f}GB", + f"State={bytes_to(state_cache_bytes_total, unit='GB'):.2f}GB", + f"Unmanaged={bytes_to(unmanaged_cache_bytes_total, unit='GB'):.2f}GB", + f"Total={bytes_to(total_cache_bytes, unit='GB'):.2f}GB", + ] + ad_logger.info(f"Mem info for resize: {' | '.join(mem_info)}") + ad_logger.info(f"Final Cache Mem: {' | '.join(mem_cache_info)}") + + @property + def kv_cache_manager(self) -> Optional[KVCacheManager]: + """Return the KVCacheManager managing paged resources, or None if not initialized.""" + assert self._kv_cache_manager is not None, "KVCacheManager not initialized." + return self._kv_cache_manager + + @property + def kv_cache_config_tuned(self) -> KvCacheConfig: + """Return the KVCacheConfig tuned for the KVCacheManager.""" + assert None not in [self._kv_cache_manager, self._kv_cache_config_tuned], ( + "KVCacheManager not initialized." + ) + return self._kv_cache_config_tuned + + @property + def kv_cache_config(self) -> KvCacheConfig: + """Return the original KVCacheConfig as passed in.""" + return self._kv_cache_config_original + + def _clear_cache_views(self) -> None: + """Set paged and state cache views to None before pool release.""" + self._kv_cache_config_tuned = None + for name in self._paged_cache_order: + self._caches[name] = None + for name in self._state_resource_order: + self._caches[name] = None + + def shutdown(self) -> None: + """Shutdown and release all resources.""" + if self._kv_cache_manager is not None: + self._kv_cache_manager.shutdown() + self._kv_cache_config_tuned = None + self._caches.clear() GetInferenceModel = Callable[[CachedSequenceInterface], nn.Module] diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py index 6d0ba54c29..571a632f7d 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py @@ -5,7 +5,8 @@ This module defines the base classes and interfaces for all transforms. 
import time from abc import ABC -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass from enum import Enum from functools import total_ordering, wraps from typing import Any, Callable, Dict, Mapping, Tuple, Type, Union, final @@ -24,8 +25,60 @@ from ..utils._graph import ( placeholders_on_meta, run_shape_prop, ) +from ..utils.cuda_mem_tracker import get_mem_info from ..utils.logger import ad_logger +# ANSI color codes for log formatting (set to False to disable colors) +# NOTE: colors disabled by default to make logging in CI/CD pipelines easier to read +_ENABLE_LOG_COLORS = False + + +class _Colors: + RESET = "\033[0m" if _ENABLE_LOG_COLORS else "" + BOLD = "\033[1m" if _ENABLE_LOG_COLORS else "" + DIM = "\033[2m" if _ENABLE_LOG_COLORS else "" + CYAN = "\033[36m" if _ENABLE_LOG_COLORS else "" + MAGENTA = "\033[35m" if _ENABLE_LOG_COLORS else "" + GREEN = "\033[32m" if _ENABLE_LOG_COLORS else "" + YELLOW = "\033[33m" if _ENABLE_LOG_COLORS else "" + ORANGE = "\033[38;5;208m" if _ENABLE_LOG_COLORS else "" + + +@dataclass +class MemStats: + """Memory statistics snapshot for tracking CUDA memory usage.""" + + tot: float + free: float + resv: float + alloc: float + frag: float + + def diff(self, other: "MemStats") -> "MemStats": + """Calculate the difference (self - other).""" + return MemStats( + tot=self.tot - other.tot, + free=self.free - other.free, + resv=self.resv - other.resv, + alloc=self.alloc - other.alloc, + frag=self.frag - other.frag, + ) + + def to_dict(self) -> Dict[str, float]: + """Convert to dictionary for serialization.""" + return { + "tot": self.tot, + "free": self.free, + "resv": self.resv, + "alloc": self.alloc, + "frag": self.frag, + } + + @classmethod + def from_dict(cls, d: Dict[str, float]) -> "MemStats": + """Create from dictionary.""" + return cls(tot=d["tot"], free=d["free"], resv=d["resv"], alloc=d["alloc"], frag=d["frag"]) + class TransformError(Exception): """An exception raised when a transform fails.""" @@ -114,6 +167,11 @@ class TransformConfig(BaseModel): description="Whether this transform requires shape propagation before it is applied.", ) + expect_mem_change: bool = Field( + default=False, + description="Whether this transform is expected to cause changes in CUDA memory stats.", + ) + AutodeployMeta = Dict[str, Any] _UntypedInferenceOptimizerConfig = Dict[str, Any] @@ -223,6 +281,7 @@ class BaseTransform(ABC): config: TransformConfig # overwrite type hint if other config cls is used in subclass! 
_autodeploy_meta_key: str = "_autodeploy" _history_key: str = "transform_history" + _mem_history_key: str = "mem_history" _transform_key: str # Set by TransformRegistry.register() decorator @classmethod @@ -324,6 +383,9 @@ class BaseTransform(ABC): # show debug info for debug config ad_logger.debug(f"{t_name} config: {self.config}") + # capture memory stats at the start + mem_pre = self._get_mem_stats(empty_cache=True) + # store some timing information elapsed_time_total = -time.time() elapsed_time_pre_cleanup = 0.0 @@ -340,23 +402,28 @@ class BaseTransform(ABC): self.config.requires_shape_prop, info.is_clean, info.has_valid_shapes, + phase="pre", ) elapsed_time_pre_cleanup += time.time() # run the transform in a error-handling wrapper if desired elapsed_time_apply = -time.time() - if self.config.skip_on_error: - try: + with self._apply_logging_context(): + self._log_info("applying transform...") + if self.config.skip_on_error: + try: + mod, info_apply = self._apply_per_gm_or_whole_model( + mod, cm, factory, shared_config + ) + except Exception as e: + error_msg = f"Transform {t_name} failed" + ad_logger.warning(f"{error_msg}: {e}") + info_apply = TransformInfo(skipped=True, num_matches=0) + else: + # handle this here normally to improve debugging and error message mod, info_apply = self._apply_per_gm_or_whole_model( mod, cm, factory, shared_config ) - except Exception as e: - error_msg = f"Transform {t_name} failed" - ad_logger.warning(f"{error_msg}: {e}") - info_apply = TransformInfo(skipped=True, num_matches=0) - else: - # handle this here normally to improve debugging and error message - mod, info_apply = self._apply_per_gm_or_whole_model(mod, cm, factory, shared_config) elapsed_time_apply += time.time() # we cannot say it's clean if the previous wasn't clean even if this one is @@ -371,31 +438,40 @@ class BaseTransform(ABC): self.config.run_shape_prop, info.is_clean, info.has_valid_shapes, + phase="post", ) elapsed_time_post_cleanup += time.time() elapsed_time_total += time.time() + # capture memory stats at the end and log summary (only log if enabled) + mem_post = self._get_mem_stats(empty_cache=True) + if self.config.enabled: + self._log_mem_summary(mem_pre, mem_post, self.config.expect_mem_change) + # log the result of the transform - log_msgs = [ - f"enabled={self.config.enabled}", - "skipped=True" if info.skipped else f"num_matches={info.num_matches}", - f"is_clean={info.is_clean}", - f"has_valid_shapes={info.has_valid_shapes}", - ] - self._log_info(", ".join(log_msgs)) - log_msgs_timing = [ - f"elapsed time: total={elapsed_time_total:.3f}s", - f"pre_cleanup={elapsed_time_pre_cleanup:.3f}s", - f"apply={elapsed_time_apply:.3f}s", - f"post_cleanup={elapsed_time_post_cleanup:.3f}s", - ] - self._log_info(", ".join(log_msgs_timing)) + self._log_transform_summary( + enabled=self.config.enabled, + skipped=info.skipped, + num_matches=info.num_matches, + elapsed_total=elapsed_time_total, + elapsed_pre=elapsed_time_pre_cleanup, + elapsed_apply=elapsed_time_apply, + elapsed_post=elapsed_time_post_cleanup, + ) ad_logger.debug(f"Model after {t_name}: {mod}") - # update + store new meta data + # update + store new meta data (transform history and memory history) history[t_name] = info autodeploy_meta[self._history_key] = history + + # store memory history + mem_history: Dict[str, Dict[str, Dict[str, float]]] = autodeploy_meta.get( + self._mem_history_key, {} + ) + mem_history[t_name] = {"pre": mem_pre.to_dict(), "post": mem_post.to_dict()} + autodeploy_meta[self._mem_history_key] = 
mem_history + self._set_autodeploy_meta(mod, autodeploy_meta) # return the graph module @@ -423,11 +499,161 @@ class BaseTransform(ABC): info = info & info_apply if info is not None else info_apply return mod, info + @final + def _log_warning(self, *args: any): + """Log a warning message with the transform key.""" + ad_logger.warning(*args) + @final def _log_info(self, *args: any): """Log a message with the transform key.""" ad_logger.info(*args) + @final + def _log_debug(self, *args: any): + """Log a message with the transform key.""" + ad_logger.debug(*args) + + @contextmanager + def _apply_logging_context(self): + """Context manager to add [APPLY] prefix to logs during transform execution.""" + original_log = ad_logger.log + apply_label = "[APPLY]" + + def _patched_log(severity, *msg): + # Prepend [APPLY] after any existing prefix + if msg: + first = msg[0] if isinstance(msg[0], str) else str(msg[0]) + return original_log(severity, f"{apply_label} {first}", *msg[1:]) + return original_log(severity, apply_label) + + ad_logger.log = _patched_log # type: ignore[assignment] + try: + yield + finally: + ad_logger.log = original_log # type: ignore[assignment] + + @final + def _log_cleanup_status(self, phase: str, action: str, reason: str = "") -> None: + """Log cleanup status with colored formatting. + + Args: + phase: "pre" or "post" + action: "ran" or "skipped" + reason: Description of what ran or why skipped + """ + label = f"{_Colors.CYAN}[{phase.upper()}-CLEANUP]{_Colors.RESET}" + if action == "skipped": + self._log_info(f"{label} {_Colors.DIM}skipped ({reason}){_Colors.RESET}") + else: + self._log_info(f"{label} {reason}") + + @final + def _log_transform_summary( + self, + enabled: bool, + skipped: bool, + num_matches: int, + elapsed_total: float, + elapsed_pre: float, + elapsed_apply: float, + elapsed_post: float, + ) -> None: + """Log transform summary with colored formatting. + + Args: + enabled: Whether the transform was enabled. + skipped: Whether the transform was skipped. + num_matches: Number of matches found. + elapsed_total: Total elapsed time in seconds. + elapsed_pre: Pre-cleanup elapsed time in seconds. + elapsed_apply: Apply elapsed time in seconds. + elapsed_post: Post-cleanup elapsed time in seconds. + """ + label = f"{_Colors.GREEN}[SUMMARY]{_Colors.RESET}" + timing_str = ( + f"{elapsed_total:.3f}s " + f"(pre={elapsed_pre:.3f}s, apply={elapsed_apply:.3f}s, post={elapsed_post:.3f}s)" + ) + + if not enabled: + self._log_info(f"{label} {_Colors.DIM}disabled{_Colors.RESET}") + elif skipped: + self._log_info(f"{label} {_Colors.DIM}skipped{_Colors.RESET} | time: {timing_str}") + else: + self._log_info(f"{label} matches={num_matches} | time: {timing_str}") + + @final + def _get_mem_stats(self, empty_cache: bool = True) -> MemStats: + """Get current CUDA memory statistics. + + Args: + empty_cache: Whether to empty the memory cache before getting the memory stats. + + Returns: + MemStats object with current memory values in GB. + """ + tot, free, resv, alloc, frag = get_mem_info(empty_cache=empty_cache, unit="GB") + return MemStats(tot=tot, free=free, resv=resv, alloc=alloc, frag=frag) + + @final + def _log_mem_summary(self, pre: MemStats, post: MemStats, expect_mem_change: bool) -> None: + """Log memory summary with diff between pre and post stats. + + Logs one of three cases: + 1. Expected mem change: info log, magenta color + 2. Unexpected mem change: warning log, yellow color + 3. No mem change: debug log, no colors + + Args: + pre: Memory stats captured before the transform. 
+ post: Memory stats captured after the transform. + expect_mem_change: Whether this transform is expected to cause memory changes. + """ + diff = post.diff(pre) + + # Threshold for detecting significant memory changes (in GB) + mem_change_threshold = 0.005 + + # Check if there was a significant memory change + has_mem_change = ( + abs(diff.resv) >= mem_change_threshold + or abs(diff.alloc) >= mem_change_threshold + or abs(diff.frag) >= mem_change_threshold + ) + + def _fmt_val_with_delta(val: float, delta: float, color: str) -> str: + """Format value with optional delta in the specified color.""" + val_str = f"{val:6.2f}GB" + if abs(delta) < mem_change_threshold: + return val_str + sign = "+" if delta > 0 else "" + if color: + return f"{val_str} {_Colors.BOLD}{color}({sign}{delta:.2f}GB){_Colors.RESET}" + return f"{val_str} ({sign}{delta:.2f}GB)" + + def _fmt_parts(color: str) -> str: + """Format all memory parts with the specified color for deltas.""" + parts = [ + f"free: {_fmt_val_with_delta(post.free, diff.free, color)}", + f"resv: {_fmt_val_with_delta(post.resv, diff.resv, color)}", + f"alloc: {_fmt_val_with_delta(post.alloc, diff.alloc, color)}", + f"frag: {_fmt_val_with_delta(post.frag, diff.frag, color)}", + ] + return " | ".join(parts) + + if has_mem_change and expect_mem_change: + # Case 1: Expected mem change - info log, magenta + label = f"{_Colors.MAGENTA}[CUDA MEM DIFF (EXPECTED)]{_Colors.RESET}" + self._log_info(f"{label} {_fmt_parts(_Colors.MAGENTA)}") + elif has_mem_change and not expect_mem_change: + # Case 2: Unexpected mem change - warning log, yellow + label = f"{_Colors.YELLOW}[CUDA MEM DIFF (UNEXPECTED)]{_Colors.RESET}" + self._log_warning(f"{label} {_fmt_parts(_Colors.YELLOW)}") + else: + # Case 3: No mem change - debug log, no colors + self._log_debug(f"[CUDA MEM] {_fmt_parts('')}") + @final def _get_autodeploy_meta(self, mod: nn.Module) -> AutodeployMeta: """Get the autodeploy metadata from the graphmodule.""" @@ -450,8 +676,9 @@ class BaseTransform(ABC): clean_shape: bool, is_clean: bool, has_valid_shapes: bool, + phase: str, ) -> TransformInfo: - """Run graph cleanup before the transform. + """Run graph cleanup before or after the transform. Args: mod: The model to run cleanup on. @@ -459,23 +686,28 @@ class BaseTransform(ABC): clean_shape: Whether we want clean shapes after the transform. is_clean: The current cleanup status. has_valid_shapes: The current shape propagation status. + phase: The phase of cleanup ("pre" or "post"). Returns: An info object indicating the cleanup status after this function is called. 
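+
+        Note:
+            Cleanup is skipped entirely when neither clean_graph nor clean_shape is requested,
+            or when the graph is already in the requested clean/shape-propagated state.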
""" # check if run cleanup depending on the config and info if clean_shape and not (is_clean and has_valid_shapes): - self._log_info("running graph cleanup (with shape_prop)") + self._log_cleanup_status(phase, "ran", "graph canonicalization + shape_prop") canonicalize_graph(mod) with lift_to_meta(mod) if placeholders_on_meta(mod) else nullcontext(): run_shape_prop(mod) is_clean = True has_valid_shapes = True elif clean_graph and not is_clean: - self._log_info("running graph cleanup (no shape_prop)") + self._log_cleanup_status(phase, "ran", "graph canonicalization") canonicalize_graph(mod) is_clean = True has_valid_shapes = False + elif not clean_graph and not clean_shape: + self._log_cleanup_status(phase, "skipped", "disabled") + else: + self._log_cleanup_status(phase, "skipped", "graph already clean") return TransformInfo(is_clean=is_clean, has_valid_shapes=has_valid_shapes) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py b/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py index 083128a317..f33872e74a 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py @@ -46,6 +46,9 @@ class BuildModel(BaseTransform): # build the model model = factory.build_model(self.config.device) + # update the kv cache config + cm.update_kv_cache_config(**factory.get_cache_config_updates()) + # by convention, we say the model is always clean info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/hidden_states.py b/tensorrt_llm/_torch/auto_deploy/transform/library/hidden_states.py index e1c917160c..905c45c4b6 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/hidden_states.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/hidden_states.py @@ -21,13 +21,14 @@ import torch from torch._ops import OpOverloadPacket from torch.fx import GraphModule, Node +from .....llmapi.llm_args import KvCacheConfig from ...custom_ops.attention_interface import ( AttentionDescriptor, AttentionLayout, AttentionRegistry, - CacheConfig, - CacheInitializerDict, MHACallable, + ResourceHandler, + ResourceHandlerDict, SequenceInfo, ) from ...models.factory import ModelFactory @@ -195,12 +196,30 @@ class DetectHiddenStatesForCapture(BaseTransform): return gm, info +class HiddenStatesResourceHandler(ResourceHandler): + """A resource handler for hidden states.""" + + def __init__(self, hidden_size: int, dtype: torch.dtype) -> None: + """Initialize the HiddenStatesResourceHandler. + + Args: + hidden_size: The size of the hidden states resource. + dtype: The dtype of the hidden states resource. 
+ """ + self.hidden_size = hidden_size + self.dtype = dtype + + def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor: + return torch.empty( + sequence_info.max_num_tokens, + self.hidden_size, + device=sequence_info.device, + dtype=self.dtype, + ) + + @AttentionRegistry.register("cached_residual_add") class CachedResidualAdd(AttentionDescriptor): - @classmethod - def is_paged(cls) -> bool: - return True - @classmethod def get_attention_layout(cls) -> AttentionLayout: return "bsnd" @@ -219,15 +238,12 @@ class CachedResidualAdd(AttentionDescriptor): @classmethod def get_cache_initializers( - cls, source_attn_node: Node, cache_config: CacheConfig - ) -> CacheInitializerDict: + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: hidden_size = source_attn_node.meta["val"].shape[-1] hidden_type = source_attn_node.meta["val"].dtype - def _get_hidden_states_cache(si: SequenceInfo): - return torch.empty(si.max_num_tokens, hidden_size, device=si.device, dtype=hidden_type) - - return {"hidden_states_cache": _get_hidden_states_cache} + return {"hidden_states_cache": HiddenStatesResourceHandler(hidden_size, dtype=hidden_type)} @classmethod def get_standard_metadata_args(cls) -> List[str]: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py index 43ea2c9e01..5d7c1b339d 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -18,7 +18,7 @@ import inspect import operator -from typing import Dict, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch import torch.nn as nn @@ -28,17 +28,13 @@ from torch.fx import GraphModule, Node from ...custom_ops.attention_interface import ( AttentionDescriptor, AttentionRegistry, - CacheConfig, Constant, PrepareMetadataCallable, ) -from ...distributed.common import all_gather_object, get_world_size -from ...distributed.common import is_initialized as is_distributed_initialized from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils._graph import add_graph_input -from ...utils.cuda_mem_tracker import get_mem_info_in_mb -from ...utils.logger import ad_logger +from ...utils.cuda_mem_tracker import get_mem_info from ...utils.node_utils import is_op from ..interface import ( BaseTransform, @@ -53,9 +49,6 @@ class InsertCachedAttentionConfig(TransformConfig): """Configuration for the insert cached attention transform.""" backend: Optional[str] = Field(default=None, description="The attention backend to use.") - cache_config: CacheConfig = Field( - default_factory=CacheConfig, description="The custom cache configuration to use." 
- ) @TransformRegistry.register("insert_cached_attention") @@ -154,7 +147,6 @@ class InsertCachedAttention(BaseTransform): meta_nodes_std: List[Node], meta_nodes_extra: List[Node], cache_nodes: List[Node], - buffer_nodes: List[Node], constants: List[Constant], ): """Insert a cached attention node into the graph.""" @@ -166,7 +158,6 @@ class InsertCachedAttention(BaseTransform): *meta_nodes_std, *meta_nodes_extra, *cache_nodes, - *buffer_nodes, *constants, ), ) @@ -183,10 +174,6 @@ class InsertCachedAttention(BaseTransform): """Replace uncached source attention node with corresponding cached attn node.""" attn_descriptor = self.attn_descriptor - # run field-wise or to combine the cache config from the transform and the factory - # the transform config takes precedence over the factory config - cache_config = self.config.cache_config | factory.get_cache_config() - # Get all attention nodes and their info objects source_op = attn_descriptor.get_source_attention_op() @@ -199,10 +186,6 @@ class InsertCachedAttention(BaseTransform): skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) - # Sanity check - if cm.info.is_paged: - assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op." - # get standard metadata nodes for all source attention nodes meta_nodes_std = self._process_metadata_std(gm, cm) @@ -212,8 +195,6 @@ class InsertCachedAttention(BaseTransform): # Register host-side prepare_metadata function for attention descriptor. self._process_metadata_host(cm) - buffer_in_lookup: Dict[str, Node] = {} - # replace fused attention node with attention node that has kv cache num_cached_attn_replacements = 0 for idx, attn_node in enumerate(source_attn_nodes): @@ -222,22 +203,13 @@ class InsertCachedAttention(BaseTransform): # setup + store cache initializers and caches as input nodes cache_in_nodes = [] - for k, get_cache in attn_descriptor.get_cache_initializers( - attn_node, cache_config + for k, resource_handler in attn_descriptor.get_cache_initializers( + attn_node, cm.kv_cache_config ).items(): k_indexed = f"{k}_{idx}" - cm.add_cache(k_indexed, get_cache) + cm.add_resource(k_indexed, resource_handler) cache_in_nodes.append(self._process_cache_node(gm, k_indexed)) - # setup + store global buffer initializers and buffers as input nodes - # NOTE: we have to check against existing keys to make sure nothing is registered twice... - buffer_in_nodes = [] - for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items(): - if k not in buffer_in_lookup: - cm.add_cache(k, get_buffer) - buffer_in_lookup[k] = self._process_cache_node(gm, k) - buffer_in_nodes.append(buffer_in_lookup[k]) # store buffer nodes for this op - # retrieve constants for attention_op constants = attn_descriptor.get_constants(attn_node) @@ -249,7 +221,6 @@ class InsertCachedAttention(BaseTransform): meta_nodes_std, meta_nodes_extra, cache_in_nodes, - buffer_in_nodes, constants, ) @@ -276,27 +247,15 @@ class InsertCachedMLAAttention(InsertCachedAttention): pass -class ResizeKVCacheConfig(TransformConfig): - """Configuration for the resize kv cache transform.""" - - free_mem_ratio: float = Field( - default=0.0, ge=0.0, le=1.0, description="The fraction of available memory to occupy." - ) - - @TransformRegistry.register("resize_kv_cache") class ResizeKVCache(BaseTransform): - """Inflate the kv cache to occupy the available GPU memory. + """Resize the KV cache to occupy available GPU memory. - free_mem_ratio specifies the fraction of available memory to occupy. 
+ This implements the two-phase approach: + 1. Run a forward pass to allocate intermediate memory (activations, workspaces, etc.) + 2. Call resize_kv_cache_manager() to recreate KVCacheManager with optimal capacity """ - config: ResizeKVCacheConfig - - @classmethod - def get_config_class(cls) -> Type[TransformConfig]: - return ResizeKVCacheConfig - def _apply_to_full_model( self, mod: nn.Module, @@ -304,100 +263,29 @@ class ResizeKVCache(BaseTransform): factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[nn.Module, TransformInfo]: - free_mem_ratio = self.config.free_mem_ratio - - free_mem, total_mem = get_mem_info_in_mb(empty_cache=True) - self._log_info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}") - current_cache_size = cm.current_cache_size_bytes() - current_kv_cache_size = getattr(cm, "current_kv_cache_size_bytes", None) - current_kv_cache_size = ( - current_kv_cache_size() if callable(current_kv_cache_size) else current_cache_size - ) - current_num_pages = cm.info.num_pages - self._log_info( - f"Current cache size (MB): {current_cache_size // 1024**2}, " - f"Current num pages: {current_num_pages}" - ) - if current_kv_cache_size != current_cache_size: - self._log_info( - f"Current KV-only cache size (MB): {current_kv_cache_size // 1024 // 1024}" - ) - - if free_mem_ratio == 0.0: - self._log_info(f"Skipping cache resize for {free_mem_ratio=}") + # check if we need a resize or not + if not cm.needs_resize(): return mod, TransformInfo( skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) - # TODO: the manual PyTorch workflow respects max_num_tokens if set and does _NOT_ resize - # the cache in this case. Should we do the same here? - - # Let's run a forward pass to get the memory usage + # Run a forward pass to get the extra memory usage cm.info.set_max_num_tokens_sample() - free_mem_pre, _ = get_mem_info_in_mb(empty_cache=True) - self._log_info(f"Free memory before forward pass (MB): {free_mem_pre}") - - # Reset peak memory stats to get the extra memory used during the forward pass - torch.cuda.reset_peak_memory_stats() - memory_allocated_before_forward_pass_mb = torch.cuda.memory_allocated() // 1024**2 try: mod(**cm.named_args) except torch.OutOfMemoryError as e: - ad_logger.error( + self._log_info( f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n{e}" ) raise e - peak_memory_during_forward_pass_mb = torch.cuda.max_memory_allocated() // 1024**2 - mem_used_during_forward_pass_mb = ( - peak_memory_during_forward_pass_mb - memory_allocated_before_forward_pass_mb - ) - self._log_info( - f"Peak memory uasge during forward pass (MB): {peak_memory_during_forward_pass_mb}" - ) - self._log_info( - f"Extra memory used during forward pass (MB): {mem_used_during_forward_pass_mb}" - ) + # NOTE: use fragmented memory without empty cache (peak forward memory + fragmented memory) + # as a proxy for the memory reserved for the forward pass. This is a rough estimate and + # may not be accurate. + *_, mem_reserved_for_forward = get_mem_info(empty_cache=False, unit="B") - # TODO (lucaslie): logic needs overhaul, too much going on. For now, this is just reverting - # to the original logic. 
Full overhaulwill be done as part of #10013 - free_mem_post, _ = get_mem_info_in_mb(empty_cache=False) - self._log_info(f"Free memory after forward pass (MB): {free_mem_post}") - - memory_for_forward_pass = free_mem_pre - free_mem_post - self._log_info(f"Memory for forward pass (MB): {memory_for_forward_pass}") - - # Compute new pages using KV-only bytes to avoid SSM/conv inflating per-page cost - # Reserve headroom to avoid OOM from other allocations (workspaces, cudagraph pools, etc.) - reserve_mb = max(1024, (total_mem * 5) // 100) # at least 1 GiB or 5% of total - available_mb = max(0, free_mem_post - reserve_mb) - - new_kv_total_bytes = int( - available_mb * 1024 * 1024 * free_mem_ratio + current_kv_cache_size - ) - per_page_bytes = max(1, current_kv_cache_size // max(1, current_num_pages)) - new_num_pages = int(new_kv_total_bytes // per_page_bytes) - - # Need to sync all the GPUs if distributed group is initialized - log_msg = f"Using local new_num_pages: {new_num_pages}" - if is_distributed_initialized(): - gathered_num_pages = [None] * get_world_size() - all_gather_object(gathered_num_pages, new_num_pages) - new_num_pages = min(gathered_num_pages) - log_msg = f"After all_gather - new_num_pages: {new_num_pages}" - - self._log_info(log_msg) - cm.resize_cache(new_num_pages) - - # Log the final cache size for performance measurement, do not remove this log. - final_cache_size_bytes = cm.current_cache_size_bytes() - final_cache_size_gb = final_cache_size_bytes / (1024**3) # Convert to GiB - self._log_info( - f"Final KV cache size after resize: {final_cache_size_gb:.2f} GiB ({new_num_pages} pages)" - ) - - # Free memory - torch.cuda.empty_cache() + # Resize - KVCacheManager will compute optimal capacity based on free memory + cm.resize_kv_cache_manager(mem_reserved_for_forward) info = TransformInfo( skipped=False, @@ -411,6 +299,13 @@ class ResizeKVCache(BaseTransform): @TransformRegistry.register("initialize_cache") class InitializeCache(BaseTransform): + """Initialize KV caches using KVCacheManager. + + Gets kv_cache_config from shared_config.ad_config and creates the KVCacheManager + in estimation mode with conservative capacity. The ResizeKVCache transform will + later recreate it with optimal capacity after measuring memory usage. 
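+
+    Resources registered via cm.add_resource() (paged KV caches, SSM/conv state caches, and
+    unpaged local caches) are materialized in this transform by calling cm.initialize_resources().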
+ """ + def _apply_to_full_model( self, mod: nn.Module, @@ -418,7 +313,9 @@ class InitializeCache(BaseTransform): factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[nn.Module, TransformInfo]: - num_caches = cm.initialize_caches() + # Initialize with estimation mode + # This allows resize_kv_cache to recreate with correct capacity after measuring memory + num_caches = cm.initialize_resources() self._log_info(f"Initialized {num_caches} caches for cached attention") info = TransformInfo( diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py index 1f34445647..ce1dd0df3e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py @@ -163,8 +163,8 @@ def get_cached_attn( query, key, value, - # metadata+caches+buffers with name lookup set up during kvcache transform - *[kwargs[k] for k in module._node_ref.meta["metadata_cache_buffer_keys"]], + # metadata+caches with name lookup set up during kvcache transform + *[kwargs[k] for k in module._node_ref.meta["metadata_cache_keys"]], # constants set up during kvcache transform *module._node_ref.meta["constants"], ) @@ -242,17 +242,11 @@ class HFReplaceCachedAttn(InsertCachedAttention): meta_nodes_std: List[Node], meta_nodes_extra: List[Node], cache_nodes: List[Node], - buffer_nodes: List[Node], constants: List[Constant], ): """Here we now need to actually do the correct mapping of the cached attn nodes.""" - # store reference to metadata, caches, buffers, and constants for this attn node - attn_node.meta["metadata_cache_buffer_keys"] = ( - *meta_nodes_std, - *meta_nodes_extra, - *cache_nodes, - *buffer_nodes, - ) + # store reference to metadata, caches, and constants for this attn node + attn_node.meta["metadata_cache_keys"] = (*meta_nodes_std, *meta_nodes_extra, *cache_nodes) attn_node.meta["constants"] = constants def _apply_to_full_model( diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py b/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py index a2aefef706..0579543054 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py @@ -8,6 +8,7 @@ from pydantic import Field from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils._graph import move_to_device +from ...utils.cuda_mem_tracker import bytes_to from ..interface import ( BaseTransform, SharedConfig, @@ -43,10 +44,11 @@ class LoadWeightsToDevice(BaseTransform): factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[nn.Module, TransformInfo]: - factory.load_or_random_init( - mod, - device=self.config.checkpoint_device or cm.device, - ) + params_size = sum(p.numel() * p.element_size() for p in mod.parameters()) + total_size_GB = bytes_to(params_size, unit="GB") + self._log_info(f"Estimated parameters memory: {total_size_GB:.2f} GB") + + factory.load_or_random_init(mod, device=self.config.checkpoint_device or cm.device) move_to_device(mod, cm.device) info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py b/tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py index e73cec39e7..9d9ce45442 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py +++ 
b/tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py @@ -16,12 +16,15 @@ import gc from contextlib import contextmanager -from typing import Tuple +from typing import Literal, Tuple, Union import torch from .logger import ad_logger +Number = Union[int, float] +ByteUnit = Literal["B", "KB", "MB", "GB", "TB"] + @contextmanager def cuda_memory_tracker(logger=ad_logger): @@ -43,10 +46,38 @@ def cuda_memory_tracker(logger=ad_logger): logger.warning(f"Potential memory leak detected, leaked memory: {leaked} bytes") -def get_mem_info_in_mb(empty_cache: bool = True) -> Tuple[int, int]: +def bytes_to(bytes: int, *more_bytes: int, unit: ByteUnit) -> Union[Number, Tuple[Number, ...]]: + units = {"KB": 1 << 10, "MB": 1 << 20, "GB": 1 << 30, "TB": 1 << 40} + bytes_converted = (bytes,) + more_bytes + unit = unit.upper() + if unit != "B": + bytes_converted = tuple(float(x) / units[unit.upper()] for x in bytes_converted) + return bytes_converted if more_bytes else bytes_converted[0] + + +def get_mem_info( + empty_cache: bool = True, unit: ByteUnit = "B" +) -> Tuple[Number, Number, Number, Number, Number]: + """Get the memory information of the current device. + + Args: + empty_cache: Whether to empty the memory cache. + unit: The unit of the memory information. Defaults to bytes. + + Returns: + A tuple of the + - total memory, + - free memory, + - reserved memory, + - allocated memory, + - fragmented memory + in the specified unit. + """ if empty_cache: # Clear the memory cache to get the exact free memory torch.cuda.empty_cache() free_mem, total_mem = torch.cuda.mem_get_info() - MB = 1024**2 - return free_mem // MB, total_mem // MB + res_mem = torch.cuda.memory_reserved() + alloc_mem = torch.cuda.memory_allocated() + frag_mem = res_mem - alloc_mem + return bytes_to(total_mem, free_mem, res_mem, alloc_mem, frag_mem, unit=unit) diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py index ff0f0330e7..56d3cbdc1c 100644 --- a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py @@ -193,6 +193,7 @@ class MambaCacheManager(BaseResourceManager): dtype=torch.int32, device=device) + @torch.inference_mode() def _prepare_mamba_cache_blocks(self, request_ids: List[int]): self.state_indices_list.clear() for r in request_ids: diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index df673f5bd3..9dd02a392e 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -40,10 +40,10 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness): # Set it explicitly here to 8192 which is the default in build_config. "max_num_tokens": 8192, "skip_loading_weights": False, + "kv_cache_config": { + "free_gpu_memory_fraction": 0.7 + }, "transforms": { - "resize_kv_cache": { - "free_mem_ratio": 0.7 - }, "compile_model": { "backend": "torch-cudagraph", @@ -55,7 +55,7 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness): if enable_chunked_prefill: config["enable_chunked_prefill"] = True config[ - "max_num_tokens"] = 512 # NOTE: must be > max(attn_page_size, max_batch_size) + "max_num_tokens"] = 512 # NOTE: must be > max(tokens_per_block, max_batch_size) return config def get_default_sampling_params(self): @@ -110,7 +110,8 @@ class TestNemotronH(LlmapiAccuracyTestHarness): "trust_remote_code": True, # SSMs do not support cache reuse. 
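+            # free_gpu_memory_fraction below replaces the former transforms.resize_kv_cache.free_mem_ratio option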
"kv_cache_config": { - "enable_block_reuse": False + "enable_block_reuse": False, + "free_gpu_memory_fraction": 0.7 }, # Keep max_batch_size as in the PyTorch test to avoid OOM "max_batch_size": 128, @@ -120,9 +121,6 @@ class TestNemotronH(LlmapiAccuracyTestHarness): "max_num_tokens": 8192, "skip_loading_weights": False, "transforms": { - "resize_kv_cache": { - "free_mem_ratio": 0.7 - }, "compile_model": { "backend": "torch-cudagraph", "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128], @@ -132,7 +130,7 @@ class TestNemotronH(LlmapiAccuracyTestHarness): if enable_chunked_prefill: config["enable_chunked_prefill"] = True config[ - "max_num_tokens"] = 512 # NOTE: must be > max(attn_page_size, max_batch_size) + "max_num_tokens"] = 512 # NOTE: must be > max(tokens_per_block, max_batch_size) return config def get_default_sampling_params(self): @@ -169,7 +167,10 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness): "trust_remote_code": True, # SSMs do not support cache reuse. "kv_cache_config": { - "enable_block_reuse": False + "enable_block_reuse": False, + "free_gpu_memory_fraction": 0.7 + # NOTE: some accuracy benchmarks may require fp32 precision for mamba cache + # "mamba_ssm_cache_dtype": "float32", }, # Keep max_batch_size as in the PyTorch test to avoid OOM "max_batch_size": 128, @@ -180,7 +181,6 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness): "max_num_tokens": 8192, "skip_loading_weights": False, "compile_backend": "torch-cudagraph", - "free_mem_ratio": 0.7, "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128], "transforms": { "detect_sharding": { @@ -191,12 +191,6 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness): "stage": "compile", "enabled": True, }, - # NOTE: some accuracy benchmarks may require fp32 precision for mamba cache - # "insert_cached_ssm_attention": { - # "cache_config": { - # "mamba_dtype": "float32", - # }, - # }, } } @@ -213,7 +207,8 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness): kwargs = self.get_default_kwargs() # TODO: multi-stream MOE seems to increase the memory usage kwargs["max_batch_size"] = 32 - kwargs["free_mem_ratio"] = 0.4 + kwargs["kv_cache_config"] = {"free_gpu_memory_fraction": 0.4} + sampling_params = self.get_default_sampling_params() with AutoDeployLLM(model=self.MODEL_PATH_BF16, tokenizer=self.MODEL_PATH_BF16, **kwargs) as llm: @@ -279,7 +274,6 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness): "trust_remote_code": True, "skip_loading_weights": False, "compile_backend": "torch-cudagraph", - "free_mem_ratio": 0.9, "max_batch_size": 128, "max_seq_len": self.MAX_SEQ_LEN, "max_num_tokens": self.MAX_SEQ_LEN, diff --git a/tests/integration/defs/perf/test_perf.py b/tests/integration/defs/perf/test_perf.py index 62a671c099..71ad1b97af 100644 --- a/tests/integration/defs/perf/test_perf.py +++ b/tests/integration/defs/perf/test_perf.py @@ -576,7 +576,6 @@ class PerfTestConfig: extra: bool = False, # _autodeploy backend specific parameters ad_compile_backend: str = "torch-opt", - free_mem_ratio: float = 0.9, extra_runtime: str = "trtllm", skip_loading_weights: bool = False, ): @@ -636,7 +635,6 @@ class PerfTestConfig: self.extra = extra # _autodeploy backend specific parameters self.ad_compile_backend = ad_compile_backend - self.free_mem_ratio = free_mem_ratio self.extra_runtime = extra_runtime self.skip_loading_weights = skip_loading_weights # Just build engines @@ -1422,9 +1420,6 @@ class MultiMetricPerfTest(AbstractPerfScriptTestClass): 'compile_model': { 'backend': self._config.ad_compile_backend }, - 'resize_kv_cache': { - 
'free_mem_ratio': self._config.free_mem_ratio - }, }, 'runtime': self._config.extra_runtime, 'skip_loading_weights': self._config.skip_loading_weights diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py index 89e18351f3..3379ec91bc 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py @@ -18,11 +18,11 @@ from tensorrt_llm._torch.auto_deploy.transform.library.sharding import ShardingT class FakeFactory(ModelFactory): - """Dummy factory to pass cache_config for testing.""" + """Dummy factory to pass cache_config_updates for testing.""" - def __init__(self, model=None, cache_config=None, quant_config=None): + def __init__(self, model=None, cache_config_updates=None, quant_config=None): self._model = model - self.cache_config = cache_config + self.cache_config_updates = cache_config_updates self.quant_config = quant_config def build_model(self, device: str): @@ -34,8 +34,8 @@ class FakeFactory(ModelFactory): def _load_checkpoint(self, model, device): return - def get_cache_config(self): - return self.cache_config + def get_cache_config_updates(self): + return self.cache_config_updates def get_quant_config(self): return self.quant_config diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py index ae3a682cf2..39c071cb20 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py @@ -511,7 +511,10 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, An # add some defaults to llm_args llm_args["skip_loading_weights"] = True # No weight loading to speed up things - llm_args["attn_page_size"] = 4 # Make sure paging is activated despite small max_tokens + llm_args["kv_cache_config"] = { + "tokens_per_block": 4, # Make sure paging is activated despite small max_tokens + "free_gpu_memory_fraction": 0.0, # No resizing of the cache to keep the mem footprint small + } llm_args["max_batch_size"] = 2 # Minimum batching to speed up things # update with custom llm_args kwargs llm_args.update(llm_args_kwargs) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py index 9ce3dcd8ec..66b2e228f3 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py @@ -84,8 +84,7 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype, ) # make sure planner is initialized - workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) - _GlobalFlashInferPlanner.init_workspace(workspace) + _GlobalFlashInferPlanner.reset(torch.device(device)) batch_indices, positions = flashinfer.get_batch_indices_positions( qo_indptr, @@ -118,8 +117,6 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype, # CACHES k_cache, v_cache, - # BUFFERS - workspace, # CONSTANTS None, 1.0, @@ -229,8 +226,7 @@ def test_flashinfer_attention_op_decode( ) # make sure planner is initialized - workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) - 
_GlobalFlashInferPlanner.init_workspace(workspace) + _GlobalFlashInferPlanner.reset(torch.device(device)) batch_indices, positions = flashinfer.get_batch_indices_positions( qo_indptr, @@ -262,8 +258,6 @@ def test_flashinfer_attention_op_decode( # CACHES k_cache, v_cache, - # BUFFERS - workspace, # CONSTANTS None, 1.0, @@ -361,8 +355,7 @@ def test_flashinfer_attention_context_and_generate( ) # make sure planner is initialized - workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) - _GlobalFlashInferPlanner.init_workspace(workspace) + _GlobalFlashInferPlanner.reset(torch.device(device)) batch_indices, positions = flashinfer.get_batch_indices_positions( qo_indptr, @@ -395,8 +388,6 @@ def test_flashinfer_attention_context_and_generate( # CACHES k_cache, v_cache, - # BUFFERS - workspace, # CONSTANTS None, 1.0, @@ -454,7 +445,7 @@ def test_flashinfer_attention_context_and_generate( v_3 = torch.randn(BATCH_SIZE, 1, N_HEADS * D_HEAD, dtype=DTYPE).to(device) # Create FlashInferAttention class before calling the custom op - _GlobalFlashInferPlanner.reset() + _GlobalFlashInferPlanner.reset(torch.device(device)) batch_indices, positions = flashinfer.get_batch_indices_positions( qo_indptr, @@ -485,8 +476,6 @@ def test_flashinfer_attention_context_and_generate( # CACHES k_cache, v_cache, - # BUFFERS - workspace, # CONSTANTS None, 1.0, @@ -586,8 +575,7 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty ) # make sure planner is initialized - workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) - _GlobalFlashInferPlanner.init_workspace(workspace) + _GlobalFlashInferPlanner.reset(torch.device(device)) batch_indices, positions = flashinfer.get_batch_indices_positions( qo_indptr, @@ -620,8 +608,6 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty # CACHES k_cache, v_cache, - # BUFFERS - workspace, # CONSTANTS None, 1.0, @@ -748,8 +734,7 @@ def test_flashinfer_attention_with_fp8_cache( v_cache = v_cache.to(torch.float8_e4m3fn) # make sure planner is initialized - workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) - _GlobalFlashInferPlanner.init_workspace(workspace) + _GlobalFlashInferPlanner.reset(torch.device(device)) batch_indices, positions = flashinfer.get_batch_indices_positions( qo_indptr, @@ -782,8 +767,6 @@ def test_flashinfer_attention_with_fp8_cache( # CACHES k_cache, v_cache, - # BUFFERS - workspace, # CONSTANTS None, K_SCALE, @@ -859,8 +842,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() # make sure planner is initialized - workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) - _GlobalFlashInferPlanner.init_workspace(workspace) + _GlobalFlashInferPlanner.reset(torch.device(device)) batch_indices, positions = flashinfer.get_batch_indices_positions( qo_indptr, @@ -891,8 +873,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de # CACHES k_cache, v_cache, - # BUFFERS - workspace, # CONSTANTS None, 1.0, @@ -956,7 +936,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de seq_len_with_cache2_host = (offsets2 + seq_len_tensor2).cpu() # Create FlashInferAttention class before calling the custom op - _GlobalFlashInferPlanner.reset() + _GlobalFlashInferPlanner.reset(torch.device(device)) batch_indices, positions = flashinfer.get_batch_indices_positions( qo_indptr2, @@ -987,8 +967,6 @@ def 
test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de # CACHES k_cache, v_cache, - # BUFFERS - workspace, # CONSTANTS None, 1.0, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_resource_handlers.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_resource_handlers.py new file mode 100644 index 0000000000..e581e8a451 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_resource_handlers.py @@ -0,0 +1,237 @@ +"""Unit tests for ResourceHandler classes in attention_interface.py. + +Tests the new resource handler abstraction for cache management: +- PagedResourceHandler (for paged KV caches) +- StateResourceHandler (for SSM/conv states) +- UnpagedResourceHandler (for unpaged local caches) +- AttentionDescriptor.resolve_cache_dtype() +""" + +import pytest +import torch + +from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import ( + AttentionDescriptor, + ManagedResourceHandler, + PagedResourceHandler, + ResourceHandler, + SequenceInfo, + StateResourceHandler, + UnpagedResourceHandler, +) + +# ============================================================================= +# PagedResourceHandler Tests +# ============================================================================= + + +def test_paged_handler_stores_token_shape_and_dtype(): + """Verify PagedResourceHandler stores token_shape and dtype correctly.""" + handler = PagedResourceHandler(8, 64, dtype=torch.float16) + assert handler.token_shape == (8, 64) + assert handler.dtype == torch.float16 + + +def test_paged_handler_single_dimension_token_shape(): + """Test PagedResourceHandler with single dimension token shape.""" + handler = PagedResourceHandler(128, dtype=torch.bfloat16) + assert handler.token_shape == (128,) + assert handler.dtype == torch.bfloat16 + + +def test_paged_handler_multi_dimension_token_shape(): + """Test PagedResourceHandler with multiple dimension token shape.""" + handler = PagedResourceHandler(4, 8, 16, dtype=torch.float32) + assert handler.token_shape == (4, 8, 16) + assert handler.dtype == torch.float32 + + +def test_paged_handler_allocate_raises_not_implemented(): + """Verify PagedResourceHandler.allocate() raises NotImplementedError.""" + handler = PagedResourceHandler(8, 64, dtype=torch.float16) + seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4) + + with pytest.raises(NotImplementedError, match="Managed resources should not be allocated"): + handler.allocate(seq_info) + + +def test_paged_handler_is_resource_handler(): + """Verify PagedResourceHandler is a ResourceHandler subclass.""" + handler = PagedResourceHandler(8, 64, dtype=torch.float16) + assert isinstance(handler, ResourceHandler) + + +def test_paged_handler_is_managed_resource(): + """Verify PagedResourceHandler is a ManagedResourceHandler.""" + handler = PagedResourceHandler(8, 64, dtype=torch.float16) + assert isinstance(handler, ManagedResourceHandler) + + +# ============================================================================= +# StateResourceHandler Tests +# ============================================================================= + + +def test_state_handler_stores_state_shape_and_dtype(): + """Verify StateResourceHandler stores state_shape and dtype correctly.""" + handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16) + assert handler.state_shape == (4, 64, 16) + assert handler.dtype == torch.bfloat16 + + +def test_state_handler_single_dimension_state_shape(): + """Test StateResourceHandler with 
single dimension state shape.""" + handler = StateResourceHandler(256, dtype=torch.float16) + assert handler.state_shape == (256,) + assert handler.dtype == torch.float16 + + +def test_state_handler_conv_state_shape(): + """Test StateResourceHandler with typical conv state shape [in_channels, kernel_size-1].""" + handler = StateResourceHandler(512, 3, dtype=torch.bfloat16) + assert handler.state_shape == (512, 3) + assert handler.dtype == torch.bfloat16 + + +def test_state_handler_ssm_state_shape(): + """Test StateResourceHandler with typical SSM state shape [num_heads, head_dim, ssm_state_size].""" + handler = StateResourceHandler(4, 64, 16, dtype=torch.float32) + assert handler.state_shape == (4, 64, 16) + assert handler.dtype == torch.float32 + + +def test_state_handler_allocate_raises_not_implemented(): + """Verify StateResourceHandler.allocate() raises NotImplementedError.""" + handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16) + seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4) + + with pytest.raises(NotImplementedError, match="Managed resources should not be allocated"): + handler.allocate(seq_info) + + +def test_state_handler_is_resource_handler(): + """Verify StateResourceHandler is a ResourceHandler subclass.""" + handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16) + assert isinstance(handler, ResourceHandler) + + +def test_state_handler_is_managed_resource(): + """Verify StateResourceHandler is a ManagedResourceHandler.""" + handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16) + assert isinstance(handler, ManagedResourceHandler) + + +# ============================================================================= +# UnpagedResourceHandler Tests +# ============================================================================= + + +def test_unpaged_handler_stores_token_shape_and_dtype(): + """Verify UnpagedResourceHandler stores token_shape and dtype correctly.""" + handler = UnpagedResourceHandler(8, 64, dtype=torch.float16) + assert handler.token_shape == (8, 64) + assert handler.dtype == torch.float16 + + +def test_unpaged_handler_single_dimension_token_shape(): + """Test UnpagedResourceHandler with single dimension token shape.""" + handler = UnpagedResourceHandler(128, dtype=torch.bfloat16) + assert handler.token_shape == (128,) + assert handler.dtype == torch.bfloat16 + + +@pytest.mark.parametrize( + "num_kv_heads,head_dim,dtype", + [ + (8, 64, torch.float16), + (4, 128, torch.bfloat16), + (1, 64, torch.float32), + ], +) +def test_unpaged_handler_allocate_returns_correct_shape(num_kv_heads, head_dim, dtype): + """Verify UnpagedResourceHandler.allocate() returns tensor with correct shape.""" + max_batch_size = 4 + max_seq_len = 128 + + handler = UnpagedResourceHandler(num_kv_heads, head_dim, dtype=dtype) + seq_info = SequenceInfo(max_seq_len=max_seq_len, max_batch_size=max_batch_size) + seq_info.to("cuda") + + tensor = handler.allocate(seq_info) + + expected_shape = (seq_info.max_num_state_slots, max_seq_len, num_kv_heads, head_dim) + assert tensor.shape == expected_shape + assert tensor.dtype == dtype + assert tensor.device.type == "cuda" + + +def test_unpaged_handler_allocate_correct_device(): + """Verify UnpagedResourceHandler allocated tensor is on the correct device.""" + handler = UnpagedResourceHandler(8, 64, dtype=torch.float16) + seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4) + seq_info.to("cuda") + + tensor = handler.allocate(seq_info) + assert tensor.device == seq_info.device + + +def 
test_unpaged_handler_is_resource_handler(): + """Verify UnpagedResourceHandler is a ResourceHandler subclass.""" + handler = UnpagedResourceHandler(8, 64, dtype=torch.float16) + assert isinstance(handler, ResourceHandler) + + +def test_unpaged_handler_is_not_managed_resource(): + """Verify UnpagedResourceHandler is NOT a ManagedResourceHandler.""" + handler = UnpagedResourceHandler(8, 64, dtype=torch.float16) + assert not isinstance(handler, ManagedResourceHandler) + + +# ============================================================================= +# AttentionDescriptor.resolve_cache_dtype() Tests +# ============================================================================= + + +def test_resolve_cache_dtype_auto_returns_fallback_float16(): + """Test 'auto' returns the fallback dtype (float16).""" + result = AttentionDescriptor.resolve_cache_dtype("auto", torch.float16) + assert result == torch.float16 + + +def test_resolve_cache_dtype_auto_returns_fallback_bfloat16(): + """Test 'auto' returns the fallback dtype (bfloat16).""" + result = AttentionDescriptor.resolve_cache_dtype("auto", torch.bfloat16) + assert result == torch.bfloat16 + + +def test_resolve_cache_dtype_auto_returns_fallback_float32(): + """Test 'auto' returns the fallback dtype (float32).""" + result = AttentionDescriptor.resolve_cache_dtype("auto", torch.float32) + assert result == torch.float32 + + +def test_resolve_cache_dtype_explicit_float16(): + """Test explicit 'float16' dtype string resolves correctly.""" + result = AttentionDescriptor.resolve_cache_dtype("float16", torch.bfloat16) + assert result == torch.float16 + + +def test_resolve_cache_dtype_explicit_bfloat16(): + """Test explicit 'bfloat16' dtype string resolves correctly.""" + result = AttentionDescriptor.resolve_cache_dtype("bfloat16", torch.float16) + assert result == torch.bfloat16 + + +def test_resolve_cache_dtype_explicit_float32(): + """Test explicit 'float32' dtype string resolves correctly.""" + result = AttentionDescriptor.resolve_cache_dtype("float32", torch.float16) + assert result == torch.float32 + + +@pytest.mark.skipif( + torch.cuda.get_device_capability(0) < (8, 9), reason="FP8 requires compute capability >= 8.9" +) +def test_resolve_cache_dtype_explicit_fp8(): + """Test explicit 'fp8' dtype string resolves correctly.""" + result = AttentionDescriptor.resolve_cache_dtype("fp8", torch.float16) + assert result == torch.float8_e4m3fn diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index 8ceb287bc8..fa26211b95 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -329,7 +329,7 @@ def test_trtllm_fused_moe_fp8( pytest.skip(f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})") assert itype == torch.float8_e4m3fn and wtype == torch.float8_e4m3fn, ( - "FP8 test only supports float8_e4m3fn" + "FP8 test only supports torch.float8_e4m3fn" ) assert otype == torch.bfloat16 or otype == torch.float16, ( "FP8 test only supports bfloat16 or float16 output type" diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_cached_sequence_interface.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_cached_sequence_interface.py new file mode 100644 index 0000000000..7e52623808 --- /dev/null +++ 
b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_cached_sequence_interface.py @@ -0,0 +1,724 @@ +"""Unit tests for CachedSequenceInterface in interface.py. + +Tests the refactored CachedSequenceInterface which now: +- Creates SequenceInfo internally with tokens_per_block from KvCacheConfig +- Manages resources via KVCacheManager or MambaHybridCacheManager +- Supports paged resources (KV caches) and state resources (SSM states) +""" + +import pytest +import torch + +from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import ( + PagedResourceHandler, + SequenceInfo, + StateResourceHandler, + UnpagedResourceHandler, +) +from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface +from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import MambaHybridCacheManager +from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager +from tensorrt_llm.llmapi.llm_args import KvCacheConfig + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def default_kv_cache_config(): + """KvCacheConfig with default settings.""" + return KvCacheConfig() + + +@pytest.fixture +def paged_kv_cache_config(): + """KvCacheConfig with paging enabled and no resizing.""" + return KvCacheConfig( + tokens_per_block=32, + max_tokens=1024, + free_gpu_memory_fraction=0.0, # Disable dynamic resizing + ) + + +@pytest.fixture +def resizable_kv_cache_config(): + """KvCacheConfig with dynamic resizing enabled.""" + return KvCacheConfig( + tokens_per_block=32, + max_tokens=1024, + free_gpu_memory_fraction=0.5, + ) + + +# ============================================================================= +# Initialization Tests +# ============================================================================= + + +def test_init_creates_sequence_info_with_tokens_per_block(paged_kv_cache_config): + """Verify SequenceInfo is created with correct tokens_per_block from KvCacheConfig.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + assert interface.info.tokens_per_block == paged_kv_cache_config.tokens_per_block + assert interface.info.max_seq_len == 128 + assert interface.info.max_batch_size == 4 + + +def test_init_uses_default_kv_cache_config_when_not_provided(): + """Verify default KvCacheConfig is used when not provided.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + ) + + # Default KvCacheConfig should be created + assert interface.kv_cache_config is not None + # Default tokens_per_block is 64 in KvCacheConfig + assert interface.info.tokens_per_block == interface.kv_cache_config.tokens_per_block + + +def test_init_propagates_max_num_tokens(): + """Verify max_num_tokens propagates to SequenceInfo.""" + max_num_tokens = 512 + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + max_num_tokens=max_num_tokens, + device="cuda", + ) + + assert interface.info.max_num_tokens == max_num_tokens + + +def test_init_propagates_vocab_size_padded(): + """Verify vocab_size_padded propagates to SequenceInfo.""" + vocab_size_padded = 32000 + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + vocab_size_padded=vocab_size_padded, + device="cuda", + ) + + assert interface.info.vocab_size_padded == vocab_size_padded + + +def test_init_stores_device(): + 
"""Verify device is stored correctly.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda:0", + ) + + assert interface.device == "cuda:0" + + +def test_init_default_device_is_cuda(): + """Verify default device is 'cuda' when not specified.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + ) + + assert interface.device == "cuda" + + +# ============================================================================= +# Resource Registration Tests +# ============================================================================= + + +def test_add_resource_paged_handler(paged_kv_cache_config): + """Test adding a PagedResourceHandler resource.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + handler = PagedResourceHandler(8, 64, dtype=torch.float16) + interface.add_resource("k_cache_0", handler) + + assert "k_cache_0" in interface._resource_lookup + assert interface._resource_lookup["k_cache_0"] is handler + + +def test_add_resource_state_handler(paged_kv_cache_config): + """Test adding a StateResourceHandler resource.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16) + interface.add_resource("ssm_state_0", handler) + + assert interface._resource_lookup["ssm_state_0"] is handler + + +def test_add_resource_unpaged_handler(paged_kv_cache_config): + """Test adding an UnpagedResourceHandler resource.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + handler = UnpagedResourceHandler(8, 64, dtype=torch.float16) + interface.add_resource("unpaged_cache", handler) + + assert "unpaged_cache" in interface._resource_lookup + assert interface._resource_lookup["unpaged_cache"] is handler + + +def test_add_multiple_resources(paged_kv_cache_config): + """Test adding multiple resources of different types.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + k_handler = PagedResourceHandler(8, 64, dtype=torch.float16) + v_handler = PagedResourceHandler(8, 64, dtype=torch.float16) + ssm_handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16) + + interface.add_resource("k_cache_0", k_handler) + interface.add_resource("v_cache_0", v_handler) + interface.add_resource("ssm_state_0", ssm_handler) + + assert len(interface._resource_lookup) == 3 + + +# ============================================================================= +# Resource Initialization Tests +# ============================================================================= + + +def test_initialize_resources_paged_only_creates_kv_cache_manager(paged_kv_cache_config): + """Test paged-only resources create KVCacheManager (not MambaHybridCacheManager).""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + # Add only paged resources + interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + + num_caches = interface.initialize_resources() + + assert num_caches == 2 + assert isinstance(interface.kv_cache_manager, KVCacheManager) + 
assert not isinstance(interface.kv_cache_manager, MambaHybridCacheManager) + + +def test_initialize_resources_mixed_creates_mamba_hybrid_cache_manager(paged_kv_cache_config): + """Test mixed paged + state resources create MambaHybridCacheManager.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + # Add paged and state resources + interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource("ssm_state_0", StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)) + + num_caches = interface.initialize_resources() + + assert num_caches == 3 + assert isinstance(interface.kv_cache_manager, MambaHybridCacheManager) + + +def test_initialize_resources_creates_cache_views_with_correct_shape(paged_kv_cache_config): + """Verify cache views are created with correct shapes.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + num_kv_heads = 8 + head_dim = 64 + interface.add_resource( + "k_cache_0", PagedResourceHandler(num_kv_heads, head_dim, dtype=torch.float16) + ) + interface.add_resource( + "v_cache_0", PagedResourceHandler(num_kv_heads, head_dim, dtype=torch.float16) + ) + + interface.initialize_resources() + + # Check cache views exist + assert "k_cache_0" in interface._caches + assert "v_cache_0" in interface._caches + + # Check shapes: [num_blocks, tokens_per_block, num_kv_heads, head_dim] + k_cache = interface._caches["k_cache_0"] + assert k_cache is not None + assert k_cache.shape[1] == paged_kv_cache_config.tokens_per_block + assert k_cache.shape[2] == num_kv_heads + assert k_cache.shape[3] == head_dim + assert k_cache.dtype == torch.float16 + + +def test_initialize_resources_creates_state_views_with_correct_shape(paged_kv_cache_config): + """Verify state views are created with correct shapes for MambaHybridCacheManager.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + num_heads = 4 + head_dim = 64 + ssm_state_size = 16 + interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource( + "ssm_state_0", + StateResourceHandler(num_heads, head_dim, ssm_state_size, dtype=torch.bfloat16), + ) + + interface.initialize_resources() + + # Check state view exists + ssm_cache = interface._caches["ssm_state_0"] + assert ssm_cache is not None + # Shape: [num_states, num_heads, head_dim, ssm_state_size] + assert ssm_cache.shape[1] == num_heads + assert ssm_cache.shape[2] == head_dim + assert ssm_cache.shape[3] == ssm_state_size + assert ssm_cache.dtype == torch.bfloat16 + + +def test_initialize_resources_unpaged_allocated_locally(paged_kv_cache_config): + """Verify UnpagedResourceHandler resources are allocated locally (not via cache manager).""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + num_kv_heads = 8 + head_dim = 64 + interface.add_resource( + "unpaged_cache", UnpagedResourceHandler(num_kv_heads, head_dim, dtype=torch.float16) + ) + + interface.initialize_resources() + + # Check unpaged cache was allocated + assert "unpaged_cache" in interface._caches + unpaged_cache = interface._caches["unpaged_cache"] + assert unpaged_cache is not None + # 
Shape: [max_batch_size + 1, max_seq_len, num_kv_heads, head_dim] + assert unpaged_cache.shape == (4 + 1, 128, num_kv_heads, head_dim) + + +def test_is_paged_returns_true_for_paged_only(paged_kv_cache_config): + """Test is_paged() returns True when all resources are paged.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + interface.initialize_resources() + + assert interface.is_paged() is True + + +def test_is_paged_returns_false_for_hybrid(paged_kv_cache_config): + """Test is_paged() returns False when state resources exist.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource("ssm_state_0", StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)) + interface.initialize_resources() + + assert interface.is_paged() is False + + +# ============================================================================= +# KV Cache Resize Tests +# ============================================================================= + + +def test_needs_resize_returns_false_when_fraction_is_zero(paged_kv_cache_config): + """Test needs_resize() returns False when free_gpu_memory_fraction is 0.0.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + interface.initialize_resources() + + assert interface.needs_resize() is False + + +def test_needs_resize_returns_true_when_fraction_is_positive(resizable_kv_cache_config): + """Test needs_resize() returns True when free_gpu_memory_fraction is positive.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=resizable_kv_cache_config, + ) + + interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + interface.initialize_resources() + + assert interface.needs_resize() is True + + +def test_resize_kv_cache_manager_skipped_when_not_needed(paged_kv_cache_config): + """Test resize_kv_cache_manager() does nothing when resize not needed.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + interface.initialize_resources() + + # Get initial state + initial_manager = interface.kv_cache_manager + + # Resize should be a no-op + interface.resize_kv_cache_manager() + + # Manager should be the same object (no recreation) + assert interface.kv_cache_manager is initial_manager + + +# ============================================================================= +# Shutdown and Cleanup Tests +# ============================================================================= + + +def test_shutdown_clears_caches(paged_kv_cache_config): + """Test shutdown() clears all caches.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, 
dtype=torch.float16)) + interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + interface.initialize_resources() + + assert len(interface._caches) == 2 + + interface.shutdown() + + assert len(interface._caches) == 0 + + +def test_clear_cache_views_sets_views_to_none(paged_kv_cache_config): + """Test _clear_cache_views() sets paged and state cache views to None.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + interface.add_resource("ssm_state_0", StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)) + interface.initialize_resources() + + # Manually call _clear_cache_views + interface._clear_cache_views() + + # Paged and state caches should be None + assert interface._caches["k_cache_0"] is None + assert interface._caches["ssm_state_0"] is None + + +# ============================================================================= +# Configuration Update Tests +# ============================================================================= + + +def test_update_kv_cache_config_valid_field(): + """Test update_kv_cache_config() with valid field updates.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + ) + + interface.update_kv_cache_config(tokens_per_block=64) + + assert interface.kv_cache_config.tokens_per_block == 64 + + +def test_update_kv_cache_config_multiple_fields(): + """Test update_kv_cache_config() with multiple field updates.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + ) + + interface.update_kv_cache_config( + tokens_per_block=64, + max_tokens=2048, + free_gpu_memory_fraction=0.8, + ) + + assert interface.kv_cache_config.tokens_per_block == 64 + assert interface.kv_cache_config.max_tokens == 2048 + assert interface.kv_cache_config.free_gpu_memory_fraction == 0.8 + + +def test_update_kv_cache_config_invalid_field_raises(): + """Test update_kv_cache_config() raises ValueError for invalid fields.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + ) + + with pytest.raises(ValueError, match="Invalid KVCacheConfig field"): + interface.update_kv_cache_config(invalid_field=123) + + +# ============================================================================= +# named_args and args Tests +# ============================================================================= + + +def test_named_args_includes_sequence_info_and_caches(paged_kv_cache_config): + """Verify named_args includes both SequenceInfo args and caches.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + interface.initialize_resources() + + named_args = interface.named_args + + # Should contain sequence info args + assert "input_ids" in named_args + assert "position_ids" in named_args + + # Should contain cache + assert "k_cache_0" in named_args + + +def test_args_returns_tuple_of_tensors(paged_kv_cache_config): + """Verify args returns a tuple of tensors.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + ) + + interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, 
dtype=torch.float16)) + interface.initialize_resources() + + args = interface.args + + assert isinstance(args, tuple) + assert all(isinstance(a, torch.Tensor) for a in args) + + +# ============================================================================= +# to() method Tests +# ============================================================================= + + +def test_to_moves_sequence_info(paged_kv_cache_config): + """Verify to() moves SequenceInfo to the target device.""" + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cpu", + kv_cache_config=paged_kv_cache_config, + ) + + interface.to("cuda") + + assert interface.info.device.type == "cuda" + + +# ============================================================================= +# SequenceInfo API Tests +# ============================================================================= + + +def test_sequence_info_tokens_per_block_from_constructor(): + """Verify tokens_per_block is set correctly from constructor.""" + seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, tokens_per_block=32) + assert seq_info.tokens_per_block == 32 + + +def test_sequence_info_tokens_per_block_defaults_to_max_seq_len(): + """Verify tokens_per_block defaults to max_seq_len when not provided.""" + seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4) + assert seq_info.tokens_per_block == 128 + + +def test_sequence_info_estimate_cache_tokens_per_forward(): + """Test estimate_cache_tokens_per_forward() calculation.""" + # With max_num_tokens=64, max_batch_size=4, tokens_per_block=16 + # seq_len = ceil(64/4) = 16 + # num_blocks_per_seq = ceil(16/16) = 1 + # num_blocks_total = 1 * 4 = 4 + # result = 4 * 16 = 64 + seq_info = SequenceInfo( + max_seq_len=128, + max_batch_size=4, + tokens_per_block=16, + max_num_tokens=64, + ) + + result = seq_info.estimate_cache_tokens_per_forward() + + # Expected: ceil(64/4) = 16 tokens per seq + # ceil(16/16) = 1 block per seq + # 1 * 4 = 4 blocks total + # 4 * 16 = 64 tokens + assert result == 64 + + +def test_sequence_info_estimate_cache_tokens_per_forward_with_overflow(): + """Test estimate_cache_tokens_per_forward() with sequence overflow into extra blocks.""" + # With max_num_tokens=100, max_batch_size=4, tokens_per_block=16 + # seq_len = ceil(100/4) = 25 + # num_blocks_per_seq = ceil(25/16) = 2 + # num_blocks_total = 2 * 4 = 8 + # result = 8 * 16 = 128 + seq_info = SequenceInfo( + max_seq_len=128, + max_batch_size=4, + tokens_per_block=16, + max_num_tokens=100, + ) + + result = seq_info.estimate_cache_tokens_per_forward() + assert result == 128 + + +def test_sequence_info_estimate_cache_loc_capacity_no_resize(): + """Test estimate_cache_loc_capacity() when capacity is sufficient.""" + seq_info = SequenceInfo( + max_seq_len=128, + max_batch_size=4, + tokens_per_block=32, + max_num_tokens=256, + ) + + initial_capacity = seq_info._input_buffer.get_capacity("cache_loc") + + # Request a small capacity that should already be available + seq_info.estimate_cache_loc_capacity(num_blocks=4) + + # Capacity should not have changed if it was already sufficient + if initial_capacity >= 4 * 4 + 1: # num_blocks * max_batch_size + 1 + assert seq_info._input_buffer.get_capacity("cache_loc") == initial_capacity + + +def test_sequence_info_estimate_cache_loc_capacity_resizes(): + """Test estimate_cache_loc_capacity() resizes buffer when needed.""" + seq_info = SequenceInfo( + max_seq_len=128, + max_batch_size=4, + tokens_per_block=32, + max_num_tokens=128, # Small to have small initial capacity + 
) + + initial_capacity = seq_info._input_buffer.get_capacity("cache_loc") + + # Request a large capacity + large_num_blocks = 1000 + seq_info.estimate_cache_loc_capacity(num_blocks=large_num_blocks) + + expected_capacity = large_num_blocks * 4 + 1 # num_blocks * max_batch_size + 1 + if expected_capacity > initial_capacity: + assert seq_info._input_buffer.get_capacity("cache_loc") >= expected_capacity + + +def test_sequence_info_last_page_len_uses_tokens_per_block(): + """Verify nest_sequences calculates last_page_len using tokens_per_block.""" + seq_info = SequenceInfo( + max_seq_len=128, + max_batch_size=4, + tokens_per_block=16, + ) + + # Set up a sequence with 25 tokens + # last_page_len = (25 - 1) % 16 + 1 = 24 % 16 + 1 = 8 + 1 = 9 + input_ids = [[1] * 25] + seq_info.nest_sequences( + input_ids, + input_pos=0, + cache_loc=[0, 1], # 2 pages + pages_per_seq=[2], + ) + + expected_last_page_len = (25 - 1) % 16 + 1 + assert seq_info._args_list["last_page_len"][0] == expected_last_page_len + + +def test_sequence_info_page_assignments(): + """Test page_assignments property returns correct structure.""" + seq_info = SequenceInfo( + max_seq_len=128, + max_batch_size=4, + tokens_per_block=16, + ) + + # Set up two sequences with different page assignments + input_ids = [[1] * 10, [1] * 20] + seq_info.nest_sequences( + input_ids, + input_pos=0, + cache_loc=[0, 1, 2], # seq 0 has page 0, seq 1 has pages 1 and 2 + pages_per_seq=[1, 2], + ) + + page_assignments = seq_info.page_assignments + assert page_assignments == [[0], [1, 2]] diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_create_ad_executor.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_create_ad_executor.py index 222974d65a..cd3cc1cb38 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_create_ad_executor.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_create_ad_executor.py @@ -107,6 +107,7 @@ def test_create_autodeploy_executor_with_guided_decoding( 512 # placeholder to satisfy ADEngine.build_from_config ) mock_engine.cache_seq_interface.info.vocab_size_padded = vocab_size_padded + mock_engine.cache_seq_interface.max_num_state_slots = max_batch_size # Mock the specific dependencies requested, plus minimal additional mocks to prevent errors with ( diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py index 63d54aa851..50bbc715e1 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py @@ -5,10 +5,11 @@ import torch import torch.nn as nn from tensorrt_llm import SamplingParams -from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine from tensorrt_llm._torch.auto_deploy.shim.demollm import DemoEngine +from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests +from tensorrt_llm.llmapi.llm_args import KvCacheConfig class TransformerLikeModelwithFakeCachePool(nn.Module): @@ -38,12 +39,13 @@ def get_inference_model(cache_seq_interface): model = TransformerLikeModelwithFakeCachePool(vocab_size, embed_dim, hidden_dim) model.eval().to(device) + return model @pytest.mark.parametrize("engine_cls", [ADEngine, DemoEngine]) -@pytest.mark.parametrize("attn_page_size", [0, 2, 0]) -def 
test_engine(engine_cls: Type[ADEngine], attn_page_size: int):
+@pytest.mark.parametrize("tokens_per_block", [0, 2, 0])
+def test_engine(engine_cls: Type[ADEngine], tokens_per_block: int):
     """Test the SimpleEngine functionality."""
     seed = 42
     # Set random seed for model param init
@@ -55,21 +57,24 @@ def test_engine(engine_cls: Type[ADEngine], attn_page_size: int):
     max_seq_len = 64
     max_batch_size = 8
 
-    sequence_info = SequenceInfo(
+    # Create KvCacheConfig with specified tokens_per_block (fall back to max_seq_len if 0)
+    kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block or max_seq_len)
+    cache_seq_interface = CachedSequenceInterface(
         max_seq_len=max_seq_len,
         max_batch_size=max_batch_size,
-        page_size=attn_page_size,
+        device=device,
+        kv_cache_config=kv_cache_config,
     )
-    sequence_info.to(device)
+    cache_seq_interface.to(device)
 
-    engine = engine_cls(get_inference_model, sequence_info, device)
+    engine = engine_cls(get_inference_model, cache_seq_interface)
 
     # Test basic token generation
     with torch.inference_mode():
         # Test logits
         input_ids = [torch.tensor([0, 1, 2], device=device)]
-        sequence_info.reset()
-        sequence_info.nest_sequences(input_ids)
+        cache_seq_interface.info.reset()
+        cache_seq_interface.info.nest_sequences(input_ids)
         logits = engine._compute_logits()
 
         assert logits is not None, "Logits are None"
@@ -77,9 +82,11 @@ def test_engine(engine_cls: Type[ADEngine], attn_page_size: int):
         original_logits = get_inference_model(mock_input)(input_ids[0].unsqueeze(0))[0]
         assert torch.allclose(logits, original_logits, atol=1e-5), "Generated Token ID mismatch"
 
+    cache_seq_interface.shutdown()
 
-@pytest.mark.parametrize("attn_page_size", [0, 2])
-def test_demo_engine_sampling(attn_page_size: int):
+
+@pytest.mark.parametrize("tokens_per_block", [0, 2])
+def test_demo_engine_sampling(tokens_per_block: int):
     """Test sampling logic specific to DemoEngine."""
     seed = 0
     torch.manual_seed(seed)
@@ -90,19 +97,22 @@ def test_demo_engine_sampling(attn_page_size: int):
     max_seq_len = 64
     max_batch_size = 8
 
-    sequence_info = SequenceInfo(
+    # Create KvCacheConfig with specified tokens_per_block (use 32 as default if 0)
+    kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block or 32)
+    cache_seq_interface = CachedSequenceInterface(
         max_seq_len=max_seq_len,
         max_batch_size=max_batch_size,
-        page_size=attn_page_size,
+        device=device,
+        kv_cache_config=kv_cache_config,
     )
-    sequence_info.to(device)
+    cache_seq_interface.to(device)
 
-    engine = DemoEngine(get_inference_model, sequence_info, device)
+    engine = DemoEngine(get_inference_model, cache_seq_interface)
 
     with torch.inference_mode():
         input_ids = [torch.tensor([1, 2, 3, 4], device=device)]
-        sequence_info.reset()
-        sequence_info.nest_sequences(input_ids)
+        cache_seq_interface.info.reset()
+        cache_seq_interface.info.nest_sequences(input_ids)
         logits = engine._compute_logits()
 
         vocab_size = logits.size(-1)
@@ -127,6 +137,8 @@ def test_demo_engine_sampling(attn_page_size: int):
 
         torch.testing.assert_close(token_ids_1, token_ids_2)
 
+    cache_seq_interface.shutdown()
+
 
 class _DummyKVCacheManager:
     def __init__(self, tokens_per_block: int):
@@ -163,8 +175,8 @@ class _DummyRequest:
         return self._tokens
 
-@pytest.mark.parametrize("attn_page_size", [256, 2])
-def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int):
+@pytest.mark.parametrize("tokens_per_block", [256, 2])
+def test_ad_engine_chunked_prefill_equivalence(tokens_per_block: int):
     """Verify ADEngine logits match between chunked and non-chunked prefill.
We simulate chunking by splitting a single context request into two chunks and @@ -179,19 +191,22 @@ def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int): max_seq_len = 64 max_batch_size = 8 - sequence_info = SequenceInfo( + # Create KvCacheConfig with specified tokens_per_block + kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block) + cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, - page_size=attn_page_size, + device=device, + kv_cache_config=kv_cache_config, ) - sequence_info.to(device) + cache_seq_interface.to(device) - engine = ADEngine(get_inference_model, sequence_info, device) + engine = ADEngine(get_inference_model, cache_seq_interface) # A simple prompt; model is position-wise so last token dominates the last logit tokens = [1, 2, 3, 4, 5, 6] - kv_manager = _DummyKVCacheManager(tokens_per_block=attn_page_size) + kv_manager = _DummyKVCacheManager(tokens_per_block=tokens_per_block) resource_manager = _DummyResourceManager(kv_manager) # No-chunk: whole prompt in one request @@ -215,3 +230,237 @@ def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int): logits_chunked_last = engine.forward(scheduled_requests_part2, resource_manager)["logits"][-1] torch.testing.assert_close(logits_full_last, logits_chunked_last) # , atol=1e-5) + + cache_seq_interface.shutdown() + + +# ============================================================================= +# Hybrid Cache Manager Integration Tests +# ============================================================================= + + +class _DummyHybridKVCacheManager: + """Simulates MambaHybridCacheManager with mamba_cache_index.""" + + def __init__(self, tokens_per_block: int, num_slots: int = 8): + self.tokens_per_block = tokens_per_block + # mamba_cache_index maps request_id to slot_idx + self.mamba_cache_index = {i: num_slots - 1 - i for i in range(num_slots)} + self.mamba_cache_free_blocks = num_slots + + def get_cache_indices(self, request): + return list(range(1024)) + + def get_num_kv_blocks(self, num_tokens: int) -> int: + if self.tokens_per_block and self.tokens_per_block > 0: + return (num_tokens + self.tokens_per_block - 1) // self.tokens_per_block + return num_tokens + + def get_num_free_blocks(self): + return 100 + + +class _DummyRequestWithRequestId: + """Request with py_request_id for hybrid cache manager testing.""" + + def __init__( + self, + tokens: List[int], + begin: int, + size: int, + seq_slot: int = 0, + request_id: int = 0, + ): + self._tokens = tokens + self.context_current_position = begin + self.context_chunk_size = size + self.seq_slot = seq_slot + self.py_request_id = request_id + self.py_batch_idx = None + self.py_multimodal_data = None + + def get_tokens(self, _beam: int) -> List[int]: + return self._tokens + + +def test_ad_engine_prepare_inputs_with_hybrid_cache_manager(): + """Test ADEngine _prepare_inputs uses mamba_cache_index when available.""" + seed = 42 + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + device = torch.device("cuda") + max_seq_len = 64 + max_batch_size = 8 + tokens_per_block = 16 + + kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block) + cache_seq_interface = CachedSequenceInterface( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + device=device, + kv_cache_config=kv_cache_config, + ) + cache_seq_interface.to(device) + + engine = ADEngine(get_inference_model, cache_seq_interface) + + # Create hybrid KV cache manager with specific 
mamba_cache_index mapping + hybrid_manager = _DummyHybridKVCacheManager(tokens_per_block=tokens_per_block) + + class _HybridResourceManager: + def __init__(self, kv_mgr): + self._kv = kv_mgr + + def get_resource_manager(self, _): + return self._kv + + resource_manager = _HybridResourceManager(hybrid_manager) + + # Create request with specific request_id + request_id = 3 + tokens = [1, 2, 3, 4] + req = _DummyRequestWithRequestId( + tokens=tokens, + begin=0, + size=len(tokens), + seq_slot=0, + request_id=request_id, + ) + + scheduled = ScheduledRequests() + scheduled.context_requests.append(req) + + # Call _prepare_inputs + engine._prepare_inputs(scheduled, resource_manager, new_tokens=None) + + # Verify slot_idx was taken from mamba_cache_index, not seq_slot + expected_slot_idx = hybrid_manager.mamba_cache_index[request_id] + actual_slot_idx = cache_seq_interface.info._args_list["slot_idx"][0] + assert actual_slot_idx == expected_slot_idx + + cache_seq_interface.shutdown() + + +def test_ad_engine_prepare_inputs_generation_with_hybrid_cache(): + """Test ADEngine _prepare_inputs handles generation requests with hybrid cache.""" + seed = 42 + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + device = torch.device("cuda") + max_seq_len = 64 + max_batch_size = 8 + tokens_per_block = 16 + + kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block) + cache_seq_interface = CachedSequenceInterface( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + device=device, + kv_cache_config=kv_cache_config, + ) + cache_seq_interface.to(device) + + engine = ADEngine(get_inference_model, cache_seq_interface) + + # Create hybrid KV cache manager + hybrid_manager = _DummyHybridKVCacheManager(tokens_per_block=tokens_per_block) + + class _HybridResourceManager: + def __init__(self, kv_mgr): + self._kv = kv_mgr + + def get_resource_manager(self, _): + return self._kv + + resource_manager = _HybridResourceManager(hybrid_manager) + + # Create generation request + class _GenRequest: + def __init__(self, request_id: int, seq_slot: int, num_tokens: int): + self.py_request_id = request_id + self.seq_slot = seq_slot + self.py_batch_idx = None + self.is_dummy = False + self.py_draft_tokens = [] + + # Mock methods for generation request + def get_token(beam, idx): + return 42 # Dummy token + + self.get_token = get_token + self.get_num_tokens = lambda beam: num_tokens + self.max_beam_num_tokens = num_tokens + + def get_draft_token_length(self): + return 0 + + # Create a generation request with specific request_id + request_id = 2 + gen_req = _GenRequest(request_id=request_id, seq_slot=5, num_tokens=10) + + scheduled = ScheduledRequests() + scheduled.generation_requests.append(gen_req) + + # Call _prepare_inputs + engine._prepare_inputs(scheduled, resource_manager, new_tokens=None) + + # Verify slot_idx was taken from mamba_cache_index + expected_slot_idx = hybrid_manager.mamba_cache_index[request_id] + actual_slot_idx = cache_seq_interface.info._args_list["slot_idx"][0] + assert actual_slot_idx == expected_slot_idx + + cache_seq_interface.shutdown() + + +def test_ad_engine_with_regular_kv_cache_manager(): + """Test ADEngine falls back to seq_slot when mamba_cache_index not available.""" + seed = 42 + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + device = torch.device("cuda") + max_seq_len = 64 + max_batch_size = 8 + tokens_per_block = 16 + + kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block) + 
cache_seq_interface = CachedSequenceInterface( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + device=device, + kv_cache_config=kv_cache_config, + ) + cache_seq_interface.to(device) + + engine = ADEngine(get_inference_model, cache_seq_interface) + + # Use regular (non-hybrid) KV cache manager without mamba_cache_index + regular_manager = _DummyKVCacheManager(tokens_per_block=tokens_per_block) + resource_manager = _DummyResourceManager(regular_manager) + + # Create request with specific seq_slot + expected_seq_slot = 3 + tokens = [1, 2, 3, 4] + req = _DummyRequest( + tokens=tokens, + begin=0, + size=len(tokens), + seq_slot=expected_seq_slot, + ) + + scheduled = ScheduledRequests() + scheduled.context_requests.append(req) + + # Call _prepare_inputs + engine._prepare_inputs(scheduled, resource_manager, new_tokens=None) + + # Verify slot_idx falls back to seq_slot + actual_slot_idx = cache_seq_interface.info._args_list["slot_idx"][0] + assert actual_slot_idx == expected_seq_slot + + cache_seq_interface.shutdown() diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py index 8589253a04..ce4812dfe3 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py @@ -4,7 +4,6 @@ import pydantic import pytest from tensorrt_llm._torch.auto_deploy import LLM, DemoLLM, LlmArgs -from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer def test_custom_values(): @@ -14,7 +13,6 @@ def test_custom_values(): "model_factory": "AutoModelForImageTextToText", "model_kwargs": {"custom_param": True}, "skip_loading_weights": True, - "attn_page_size": 128, "max_seq_len": 2048, "transforms": { "detect_sharding": { @@ -25,10 +23,6 @@ def test_custom_values(): "stage": "cache_init", "backend": "flashinfer", }, - "resize_kv_cache": { - "stage": "cache_init", - "free_mem_ratio": 0.9, - }, }, } @@ -39,32 +33,12 @@ def test_custom_values(): "custom_param": True, } assert args.skip_loading_weights - assert args.transforms["resize_kv_cache"]["free_mem_ratio"] == 0.9 assert args.transforms["detect_sharding"]["simple_shard_only"] - assert args.attn_page_size == 128 assert args.max_seq_len == 2048 # backend should be overridden if it was 'TRTLLM' assert args.transforms["insert_cached_attention"]["backend"] == "flashinfer" -def test_free_mem_ratio_validation(): - """Test free_mem_ratio validation.""" - - def get_transform_config(free_mem_ratio): - return {"resize_kv_cache": {"stage": "cache_init", "free_mem_ratio": free_mem_ratio}} - - # Valid values - InferenceOptimizer(None, get_transform_config(0.0)) - InferenceOptimizer(None, get_transform_config(1.0)) - InferenceOptimizer(None, get_transform_config(0.5)) - - # Invalid values - with pytest.raises(ValueError): - InferenceOptimizer(None, get_transform_config(-0.1)) - with pytest.raises(ValueError): - InferenceOptimizer(None, get_transform_config(1.1)) - - # ================================ # Config Flow Tests # ================================ @@ -77,7 +51,6 @@ def test_config_params(): "model": "test-model", "model_factory": "AutoModelForImageTextToText", "skip_loading_weights": True, - "attn_page_size": 17, "max_seq_len": 19, "max_batch_size": 5, "world_size": 3, @@ -90,10 +63,6 @@ def test_config_params(): "stage": "cache_init", "backend": "flashinfer", }, - "resize_kv_cache": { - "stage": "cache_init", - "free_mem_ratio": 0.7, - 
}, }, } @@ -151,16 +120,11 @@ def test_config_flow( # Common assertions for both APIs assert instance.args.model_factory == test_config_params["model_factory"] - assert ( - instance.args.transforms["resize_kv_cache"]["free_mem_ratio"] - == test_config_params["transforms"]["resize_kv_cache"]["free_mem_ratio"] - ) assert ( instance.args.transforms["detect_sharding"]["simple_shard_only"] == test_config_params["transforms"]["detect_sharding"]["simple_shard_only"] ) assert instance.args.skip_loading_weights == test_config_params["skip_loading_weights"] - assert instance.args.attn_page_size == test_config_params["attn_page_size"] assert instance.args.max_seq_len == test_config_params["max_seq_len"] assert instance.args.max_batch_size == test_config_params["max_batch_size"] @@ -215,23 +179,6 @@ def test_parallel_config_validation(parallel_field, invalid_value): LlmArgs(**kwargs) -@pytest.mark.parametrize( - "backend,expected_attn_page_size", - [ - ("flashinfer", 64), # Default attn_page_size - ("triton", 1024), # Should equal max_seq_len - ], -) -def test_attention_backend_page_size_logic(backend, expected_attn_page_size): - """Test attn_page_size logic for different attention backends.""" - args = LlmArgs( - model="test-model", - max_seq_len=1024, - transforms={"insert_cached_attention": {"stage": "cache_init", "backend": backend}}, - ) - assert args.attn_page_size == expected_attn_page_size - - # ================================ # CUDA Graph Batch Sizes Tests # ================================ @@ -351,7 +298,7 @@ class TestSequenceInfoExampleBatchSize: max_batch_size=1, max_seq_len=128, max_num_tokens=128, - page_size=64, + tokens_per_block=64, ) # Set example sequence (this is what's used during export) @@ -371,7 +318,7 @@ class TestSequenceInfoExampleBatchSize: max_batch_size=32, max_seq_len=128, max_num_tokens=128, - page_size=64, + tokens_per_block=64, ) seq_info.set_example_sequence() diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py index e03d366911..2ef8033349 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py @@ -41,8 +41,10 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs): ( "meta-llama/Meta-Llama-3.1-8B-Instruct", { + "kv_cache_config": { + "free_gpu_memory_fraction": 0.0001, + }, "transforms": { - "resize_kv_cache": {"free_mem_ratio": 0.0001}, "insert_cached_attention": {"backend": "flashinfer"}, # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9878 # "compile_model": {"backend": "torch-opt"}, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py index 4dbe980802..86f47c3c70 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py @@ -84,25 +84,27 @@ def prepare_dataset(root_dir: str, temp_dir: str, model_path_or_name: str): @pytest.mark.parametrize("compile_backend", ["torch-compile", "torch-opt", "torch-cudagraph"]) @pytest.mark.parametrize("model_name", ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]) def test_trtllm_bench(llm_root, compile_backend, model_name): # noqa: F811 - config = get_small_model_config(model_name) + args = get_small_model_config(model_name)["args"] + # remove kv_cache_config 
and max_batch_size to avoid conflicts with trtllm-bench + args.pop("kv_cache_config", None) + args.pop("max_batch_size", None) with tempfile.TemporaryDirectory() as temp_dir: extra_llm_api_options_path = f"{temp_dir}/extra_llm_api_options.yaml" with open(extra_llm_api_options_path, "w") as f: yaml.dump( { - **config["args"], + **args, "transforms": { + "resize_kv_cache": {"enabled": False}, # rely on default estimation "compile_model": { "stage": "compile", "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128], "backend": compile_backend, - } + }, }, }, f, ) - dataset_path = prepare_dataset(llm_root, temp_dir, config["args"]["model"]) - run_benchmark( - model_name, str(config["args"]["model"]), dataset_path, extra_llm_api_options_path - ) + dataset_path = prepare_dataset(llm_root, temp_dir, args["model"]) + run_benchmark(model_name, str(args["model"]), dataset_path, extra_llm_api_options_path) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py index abf5d6e1d5..b353aa1bf7 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py @@ -22,7 +22,6 @@ from torch.export import Dim from torch.fx import GraphModule # Import to register custom op -from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer @@ -154,13 +153,12 @@ class TestGatherLogitsBeforeLmHeadTransform: def _create_cached_sequence_interface(self, max_batch_size: int = 8, device: str = "cuda"): """Create a mock CachedSequenceInterface for testing.""" - seq_info = SequenceInfo( + return CachedSequenceInterface( max_seq_len=64, max_batch_size=max_batch_size, + device=device, max_num_tokens=1024, ) - seq_info.to(device) - return CachedSequenceInterface(seq_info, device=device) def _check_gather_op_in_graph(self, gm: GraphModule) -> bool: """Check if gather_logits_before_lm_head op is in the graph.""" diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py index d67d790a47..d449e30ac1 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py @@ -1,4 +1,5 @@ from typing import List, Optional +from unittest.mock import MagicMock import pytest import torch @@ -6,7 +7,8 @@ import torch.nn as nn from _model_test_utils import GQA from _torch_test_utils import all_close -from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig, SequenceInfo +# Initialize resources first +from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import PagedResourceHandler from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.models.factory import ( FullModelExportInfo, @@ -14,15 +16,18 @@ from tensorrt_llm._torch.auto_deploy.models.factory import ( SubModuleExportInfo, ) 
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface +from tensorrt_llm._torch.auto_deploy.transform.interface import Stages, TransformConfig +from tensorrt_llm._torch.auto_deploy.transform.library.kvcache import InitializeCache, ResizeKVCache from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer +from tensorrt_llm.llmapi.llm_args import KvCacheConfig class DummyFactory(ModelFactory): - """Dummy factory to pass cache_config for testing.""" + """Dummy factory to pass cache_config_updates for testing.""" - def __init__(self, model, cache_config): + def __init__(self, model, cache_config_updates): self._model = model - self.cache_config = cache_config + self.cache_config_updates = cache_config_updates def build_model(self, device: str): return self._model.to(device=device) @@ -33,8 +38,8 @@ class DummyFactory(ModelFactory): def _load_checkpoint(self, model, device): return - def get_cache_config(self): - return self.cache_config + def get_cache_config_updates(self): + return self.cache_config_updates def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]: return [FullModelExportInfo()] @@ -151,12 +156,19 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): max_position_embeddings = 128 vocab_size = 1000 - # set up sequence+cache objects using standard SequenceInfo - ci = SequenceInfo( + # set up sequence+cache objects using CachedSequenceInterface + # Use tokens_per_block=max_position_embeddings so each sequence fits in 1 page for the test + kv_cache_config = KvCacheConfig( + tokens_per_block=max_position_embeddings, + max_tokens=batch_size * max_position_embeddings, + free_gpu_memory_fraction=0.0, # Disable dynamic resizing for test + ) + cm = CachedSequenceInterface( max_seq_len=max_position_embeddings, max_batch_size=batch_size, + device="cuda", + kv_cache_config=kv_cache_config, ) - cm = CachedSequenceInterface(sequence_info=ci, device="cuda") # Create the model with embedding layer and SDPA, wrap it in a fake factory model = GQAWithSdpaAndEmbedding( @@ -175,7 +187,7 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): # Apply the transformation optimizer = InferenceOptimizer( - DummyFactory(model, CacheConfig()), + DummyFactory(model, cache_config_updates={}), { "build_model": { "stage": "factory", @@ -204,7 +216,7 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): gm = optimizer(cm) gm.to("cuda") - num_caches = cm.initialize_caches() + num_caches = cm.initialize_resources() print(f"num_caches: {num_caches}") # Helper function to call the model with proper sequence nesting @@ -248,3 +260,273 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): # Test 4: Exportability of the transformed model exported_gm = torch_export_to_gm(gm, args=(), kwargs=cm.named_args) assert exported_gm is not None + + +# ============================================================================= +# Transform Unit Tests for Refactored Pipeline +# ============================================================================= + + +@pytest.fixture +def dummy_cached_interface(): + """Create a CachedSequenceInterface for transform testing.""" + kv_cache_config = KvCacheConfig( + tokens_per_block=32, + max_tokens=256, + free_gpu_memory_fraction=0.0, + ) + return CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=kv_cache_config, + ) + + +def test_initialize_cache_transform_calls_initialize_resources(dummy_cached_interface): + """Verify 
InitializeCache transform calls cm.initialize_resources().""" + # Create a mock module + mock_module = MagicMock() + + # Create the transform with a proper config + transform = InitializeCache(config=TransformConfig(stage=Stages.PATTERN_MATCHER)) + + # Add a resource to verify initialize_resources is called + from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import PagedResourceHandler + + dummy_cached_interface.add_resource( + "k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16) + ) + + # Mock the factory and shared_config + mock_factory = MagicMock() + mock_shared_config = MagicMock() + + # Run the transform + result_mod, info = transform._apply_to_full_model( + mock_module, dummy_cached_interface, mock_factory, mock_shared_config + ) + + # Verify caches were initialized + assert info.skipped is False + assert info.num_matches >= 1 + assert dummy_cached_interface.kv_cache_manager is not None + + +def test_resize_kv_cache_transform_skipped_when_not_needed(dummy_cached_interface): + """Verify ResizeKVCache transform is skipped when resize not needed.""" + dummy_cached_interface.add_resource( + "k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16) + ) + dummy_cached_interface.initialize_resources() + + # Create the transform with a proper config + transform = ResizeKVCache(config=TransformConfig(stage=Stages.PATTERN_MATCHER)) + + # Create a mock module + mock_module = MagicMock() + + # Mock forward call + mock_module.side_effect = lambda **kwargs: None + + mock_factory = MagicMock() + mock_shared_config = MagicMock() + + # Run the transform - should be skipped since free_gpu_memory_fraction=0.0 + result_mod, info = transform._apply_to_full_model( + mock_module, dummy_cached_interface, mock_factory, mock_shared_config + ) + + # Verify transform was skipped + assert info.skipped is True + + +def test_resize_kv_cache_transform_runs_when_needed(): + """Verify ResizeKVCache transform runs when resize is needed.""" + # Create interface with resizing enabled + kv_cache_config = KvCacheConfig( + tokens_per_block=32, + max_tokens=256, + free_gpu_memory_fraction=0.5, # Enable resizing + ) + cm = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=kv_cache_config, + ) + + cm.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)) + cm.initialize_resources() + + # Create the transform with a proper config + transform = ResizeKVCache(config=TransformConfig(stage=Stages.PATTERN_MATCHER)) + + # Create a simple mock module that just returns None + class MockModule: + def __call__(self, **kwargs): + return None + + mock_module = MockModule() + mock_factory = MagicMock() + mock_shared_config = MagicMock() + + # Run the transform + result_mod, info = transform._apply_to_full_model( + mock_module, cm, mock_factory, mock_shared_config + ) + + # Verify transform was not skipped + assert info.skipped is False + + +def test_insert_cached_attention_uses_add_resource(): + """Verify InsertCachedAttention uses cm.add_resource() for cache registration.""" + # This test verifies the integration point between InsertCachedAttention + # and CachedSequenceInterface.add_resource() by checking that after the + # transform, resources are registered in the interface. 
+ + num_attention_heads = 8 + hidden_size = 512 + num_key_value_heads = 8 + vocab_size = 1000 + batch_size = 4 + max_seq_len = 64 + + kv_cache_config = KvCacheConfig( + tokens_per_block=max_seq_len, + max_tokens=batch_size * max_seq_len, + free_gpu_memory_fraction=0.0, + ) + cm = CachedSequenceInterface( + max_seq_len=max_seq_len, + max_batch_size=batch_size, + device="cuda", + kv_cache_config=kv_cache_config, + ) + + # Create a model + model = GQAWithSdpaAndEmbedding( + num_attention_heads, + hidden_size, + num_key_value_heads, + vocab_size=vocab_size, + ).to(dtype=torch.float16, device="cuda") + + # Apply transformation + optimizer = InferenceOptimizer( + DummyFactory(model, cache_config_updates={}), + { + "build_model": { + "stage": "factory", + "run_per_gm": False, + "device": "cuda", + "run_graph_cleanup": False, + "requires_clean_graph": False, + }, + "export_to_gm": { + "stage": "export", + "strict": False, + "run_per_gm": False, + "clone_state_dict": True, + "run_graph_cleanup": False, + "requires_clean_graph": False, + }, + "cleanup_input_constraints": { + "stage": "post_export", + }, + "insert_cached_attention": { + "stage": "cache_init", + "backend": "triton", + }, + }, + ) + + optimizer(cm) + + # Verify resources were added + assert len(cm._resource_lookup) > 0 + # Should have k_cache and v_cache resources registered + resource_names = list(cm._resource_lookup.keys()) + assert any("k_cache" in name for name in resource_names) + assert any("v_cache" in name for name in resource_names) + + +def test_insert_cached_attention_passes_kv_cache_config(): + """Verify InsertCachedAttention passes cm.kv_cache_config to get_cache_initializers.""" + # This test verifies that the KvCacheConfig from the interface is used + # when initializing cache resources (e.g., for dtype configuration). 
+ + num_attention_heads = 8 + hidden_size = 512 + num_key_value_heads = 8 + vocab_size = 1000 + batch_size = 4 + max_seq_len = 64 + + # Use specific dtype in kv_cache_config + kv_cache_config = KvCacheConfig( + tokens_per_block=max_seq_len, + max_tokens=batch_size * max_seq_len, + free_gpu_memory_fraction=0.0, + dtype="bfloat16", # Specify explicit dtype + ) + cm = CachedSequenceInterface( + max_seq_len=max_seq_len, + max_batch_size=batch_size, + device="cuda", + kv_cache_config=kv_cache_config, + ) + + # Verify kv_cache_config is accessible + assert cm.kv_cache_config.dtype == "bfloat16" + + # Create a model + model = GQAWithSdpaAndEmbedding( + num_attention_heads, + hidden_size, + num_key_value_heads, + vocab_size=vocab_size, + ).to(dtype=torch.bfloat16, device="cuda") + + # Apply transformation + optimizer = InferenceOptimizer( + DummyFactory(model, cache_config_updates={}), + { + "build_model": { + "stage": "factory", + "run_per_gm": False, + "device": "cuda", + "run_graph_cleanup": False, + "requires_clean_graph": False, + }, + "export_to_gm": { + "stage": "export", + "strict": False, + "run_per_gm": False, + "clone_state_dict": True, + "run_graph_cleanup": False, + "requires_clean_graph": False, + }, + "cleanup_input_constraints": { + "stage": "post_export", + }, + "insert_cached_attention": { + "stage": "cache_init", + "backend": "triton", + }, + }, + ) + + optimizer(cm) + + # Initialize resources + cm.initialize_resources() + + assert not cm.is_paged(), "triton should not use paged resources" + assert cm._caches, "at least some resources should be present" + + # Verify cache dtype matches config + for name, handler in cm._resource_lookup.items(): + if hasattr(handler, "dtype"): + assert handler.dtype == torch.bfloat16