[#10013][feat] AutoDeploy: native cache manager integration (#10635)

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Lucas Liebenwein 2026-01-27 11:23:22 -05:00 committed by GitHub
parent 7f8c260601
commit ff3a494f5c
55 changed files with 2827 additions and 1098 deletions


@ -55,14 +55,17 @@ skip_loading_weights: false
# Sequence configuration
max_batch_size: 256
# transform options
# KV cache configuration
kv_cache_config:
# fraction of free memory to use for kv-caches
free_gpu_memory_fraction: 0.8
# transform options
transforms:
insert_cached_attention:
# attention backend
backend: flashinfer
resize_kv_cache:
# fraction of free memory to use for kv-caches
free_mem_ratio: 0.8
compile_model:
# compilation backend
backend: torch-opt
@ -80,7 +83,7 @@ Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GP
|-----------|---------|-------------|
| `compile_backend` | `torch-compile` | Compilation backend: `torch-simple`, `torch-compile`, `torch-cudagraph`, `torch-opt` |
| `runtime` | `trtllm` | Runtime engine: `trtllm`, `demollm` |
| `free_mem_ratio` | `0.0` | Fraction of available GPU memory for KV cache (0.0-1.0) |
| `kv_cache_config.free_gpu_memory_fraction` | `0.9` | Fraction of available GPU memory for KV cache (0.0-1.0) |
| `skip_loading_weights` | `false` | Skip weight loading for architecture-only benchmarks |
### CUDA Graph Optimization
@ -95,7 +98,7 @@ For optimal CUDA graph performance, specify batch sizes that match your expected
## Performance Optimization Tips
1. **Memory Management**: Set `free_mem_ratio` to 0.8-0.9 for optimal KV cache utilization
1. **Memory Management**: Set `kv_cache_config.free_gpu_memory_fraction` to 0.8-0.9 for optimal KV cache utilization
1. **Compilation Backend**: Use `torch-opt` for production workloads
1. **Attention Backend**: `flashinfer` generally provides the best performance for most models
1. **CUDA Graphs**: Enable CUDA graphs for batch sizes that match your production traffic patterns.
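As a minimal sketch of how these tips combine (assuming the AutoDeploy `LLM` constructor used in the examples in this change and `KvCacheConfig` from `tensorrt_llm.llmapi.llm_args`; the model path and batch sizes below are illustrative placeholders):

from tensorrt_llm.llmapi.llm_args import KvCacheConfig

llm = LLM(
    model=<HF_MODEL_OR_CHECKPOINT_DIR>,
    compile_backend="torch-opt",  # production-oriented compilation backend
    attn_backend="flashinfer",  # generally the fastest attention backend
    kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.9),  # 0.8-0.9 recommended
    cuda_graph_batch_sizes=[1, 2, 4, 8],  # match expected production batch sizes
    max_seq_len=<MAX_SEQ_LEN>,
    max_batch_size=<MAX_BATCH_SIZE>,
)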


@ -15,10 +15,8 @@ llm = LLM(
compile_backend="torch-compile",
model_kwargs={"num_hidden_layers": 2}, # test with smaller model configuration
attn_backend="flashinfer", # choose between "triton" and "flashinfer"
attn_page_size=64, # page size for attention (tokens_per_block, should be == max_seq_len for triton)
skip_loading_weights=False,
model_factory="AutoModelForCausalLM", # choose appropriate model factory
free_mem_ratio=0.8, # fraction of available memory for cache
max_seq_len=<MAX_SEQ_LEN>,
max_batch_size=<MAX_BATCH_SIZE>,
)


@ -49,14 +49,17 @@ skip_loading_weights: false
# Sequence configuration
max_batch_size: 256
# transform options
# KV cache configuration
kv_cache_config:
# fraction of free memory to use for kv-caches
free_gpu_memory_fraction: 0.9
# transform options
transforms:
insert_cached_attention:
# attention backend
backend: flashinfer
resize_kv_cache:
# fraction of free memory to use for kv-caches
free_mem_ratio: 0.8
compile_model:
# compilation backend
backend: torch-opt
@ -74,7 +77,7 @@ Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GP
|-----------|---------|-------------|
| `compile_backend` | `torch-compile` | Compilation backend: `torch-simple`, `torch-compile`, `torch-cudagraph`, `torch-opt` |
| `runtime` | `trtllm` | Runtime engine: `trtllm`, `demollm` |
| `free_mem_ratio` | `0.0` | Fraction of available GPU memory for KV cache (0.0-1.0) |
| `kv_cache_config.free_gpu_memory_fraction` | `0.9` | Fraction of available GPU memory for KV cache (0.0-1.0) |
| `skip_loading_weights` | `false` | Skip weight loading for architecture-only benchmarks |
### CUDA Graph Optimization
@ -89,7 +92,7 @@ For optimal CUDA graph performance, specify batch sizes that match your expected
## Performance Optimization Tips
1. **Memory Management**: Set `free_mem_ratio` to 0.8-0.9 for optimal KV cache utilization
1. **Memory Management**: Set `kv_cache_config.free_gpu_memory_fraction` to 0.8-0.9 for optimal KV cache utilization
1. **Compilation Backend**: Use `torch-opt` for production workloads
1. **Attention Backend**: `flashinfer` generally provides the best performance for most models
1. **CUDA Graphs**: Enable CUDA graphs for batch sizes that match your production traffic patterns.


@ -54,14 +54,17 @@ max_batch_size: 256
# multi-gpu execution
world_size: 1
# transform options
# KV cache configuration
kv_cache_config:
# fraction of free memory to use for kv-caches
free_gpu_memory_fraction: 0.9
# transform options
transforms:
insert_cached_attention:
# attention backend
backend: flashinfer
resize_kv_cache:
# fraction of free memory to use for kv-caches
free_mem_ratio: 0.8
compile_model:
# compilation backend
backend: torch-opt
@ -77,7 +80,7 @@ transforms:
- Prefer `compile_backend: torch-opt`
- Use `attn_backend: flashinfer`
- Set realistic `cuda_graph_batch_sizes` that match expected traffic
- Tune `free_mem_ratio` to 0.8-0.9
- Tune `kv_cache_config.free_gpu_memory_fraction` to 0.8-0.9
## See also


@ -17,12 +17,9 @@ llm = LLM(
transforms={
"insert_cached_attention": {"backend": "flashinfer"}, # or "triton"
"insert_cached_mla_attention": {"backend": "MultiHeadLatentAttention"},
"resize_kv_cache": {"free_mem_ratio": 0.8},
"compile_model": {"backend": "torch-compile"},
"detect_sharding": {"simple_shard_only": False},
},
attn_page_size=64, # page size for attention
skip_loading_weights=False,
max_seq_len=<MAX_SEQ_LEN>,
max_batch_size=<MAX_BATCH_SIZE>,


@ -49,7 +49,6 @@ Below is a non-exhaustive list of common configuration options:
| `--args.mla-backend` | Specifies implementation for multi-head latent attention |
| `--args.max-seq-len` | Maximum sequence length for inference/cache |
| `--args.max-batch-size` | Maximum dimension for statically allocated KV cache |
| `--args.attn-page-size` | Page size for attention |
| `--prompt.batch-size` | Number of queries to generate |
| `--benchmark.enabled` | Whether to run the built-in benchmark (true/false) |
@ -125,10 +124,8 @@ llm = LLM(
compile_backend="torch-compile",
model_kwargs={"num_hidden_layers": 2}, # test with smaller model configuration
attn_backend="flashinfer", # choose between "triton" and "flashinfer"
attn_page_size=64, # page size for attention (tokens_per_block, should be == max_seq_len for triton)
skip_loading_weights=False,
model_factory="AutoModelForCausalLM", # choose appropriate model factory
free_mem_ratio=0.8, # fraction of available memory for cache
max_seq_len=<MAX_SEQ_LEN>,
max_batch_size=<MAX_BATCH_SIZE>,
)


@ -3,7 +3,6 @@
max_batch_size: 1024
max_num_tokens: 2048
free_mem_ratio: 0.9
trust_remote_code: true
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024]
kv_cache_config:


@ -3,7 +3,6 @@
max_batch_size: 1024
max_num_tokens: 2048
free_mem_ratio: 0.9
trust_remote_code: true
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024]
kv_cache_config:


@ -6,12 +6,12 @@ enable_chunked_prefill: true
attn_backend: flashinfer
model_factory: AutoModelForCausalLM
skip_loading_weights: false
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9884
free_mem_ratio: 0.88
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384]
kv_cache_config:
# disable kv_cache reuse since not supported for hybrid/ssm models
enable_block_reuse: false
free_gpu_memory_fraction: 0.88
# tunable mamba cache dtype
# --> use float32 for accuracy and default (auto) for speed
mamba_ssm_cache_dtype: auto
transforms:
detect_sharding:
allreduce_strategy: SYMM_MEM
@ -39,12 +39,6 @@ transforms:
multi_stream_moe:
stage: compile
enabled: true
# tunable mamba cache dtype
# --> use float32 for accuracy and default (null) for speed
insert_cached_ssm_attention:
cache_config:
# mamba_dtype: float32
mamba_dtype: null
gather_logits_before_lm_head:
# TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default
enabled: true


@ -4,11 +4,7 @@ max_seq_len: 2097152
max_num_tokens: 8192
enable_chunked_prefill: true
model_factory: NemotronFlashForCausalLM
free_mem_ratio: 0.9
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 96, 128, 256, 320, 384]
kv_cache_config:
# disable kv_cache reuse since not supported for hybrid/ssm models
enable_block_reuse: false
transforms:
gather_logits_before_lm_head:
# TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default


@ -6,11 +6,11 @@ enable_chunked_prefill: true
attn_backend: flashinfer
model_factory: AutoModelForCausalLM
skip_loading_weights: false
free_mem_ratio: 0.9
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384]
kv_cache_config:
# disable kv_cache reuse since not supported for hybrid/ssm models
enable_block_reuse: false
# tunable mamba cache dtype
# --> use float32 for accuracy and default (auto) for speed
mamba_ssm_cache_dtype: auto
transforms:
detect_sharding:
allreduce_strategy: SYMM_MEM
@ -38,12 +38,6 @@ transforms:
multi_stream_moe:
stage: compile
enabled: false
# tunable mamba cache dtype
# --> use float32 for accuracy and default (null) for speed
insert_cached_ssm_attention:
cache_config:
# mamba_dtype: float32
mamba_dtype: null
gather_logits_before_lm_head:
# TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default
enabled: true


@ -37,15 +37,16 @@ enable_chunked_prefill: true
attn_backend: flashinfer
model_factory: AutoModelForCausalLM
skip_loading_weights: false
free_mem_ratio: 0.9
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384]
kv_cache_config:
# disable kv_cache reuse since not supported for hybrid/ssm models
enable_block_reuse: false
free_gpu_memory_fraction: 0.88
# tunable mamba cache dtype
# --> use float32 for accuracy and default (auto) for speed
# mamba_ssm_cache_dtype: float32
transforms:
detect_sharding:
sharding_dims: ['ep', 'bmm']
allreduce_strategy: 'SYMM_MEM'
sharding_dims: ['ep', 'bmm']
manual_config:
head_dim: 128
tp_plan:
@ -69,9 +70,8 @@ transforms:
multi_stream_moe:
stage: compile
enabled: true
insert_cached_ssm_attention:
cache_config:
mamba_dtype: float32
gather_logits_before_lm_head:
enabled: true
fuse_mamba_a_log:
stage: post_load_fusion
enabled: true


@ -99,9 +99,11 @@ transforms:
stage: weight_load
run_per_gm: false
checkpoint_device: null
expect_mem_change: true
move_inputs_to_device:
stage: weight_load
run_per_gm: false
expect_mem_change: true
############################################################################################
# RUN POST-LOAD FUSION AND OPTIMIZATIONS
############################################################################################
@ -122,18 +124,17 @@ transforms:
backend: trtllm
fuse_moe:
stage: post_load_fusion
enabled: true
expect_mem_change: true
backend: trtllm
fuse_fp8_moe:
stage: post_load_fusion
enabled: true
expect_mem_change: true
backend: trtllm
fuse_nvfp4_moe:
stage: post_load_fusion
enabled: true
expect_mem_change: true
fuse_allreduce_residual_rmsnorm:
stage: post_load_fusion
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
fuse_rmsnorm:
stage: post_load_fusion
rmsnorm_backend: flashinfer
@ -175,11 +176,12 @@ transforms:
backend: cached_residual_add
initialize_cache:
stage: cache_init
expect_mem_change: true
run_per_gm: false
resize_kv_cache:
stage: cache_init
expect_mem_change: true
run_per_gm: false
free_mem_ratio: 0.0
############################################################################################
# COMPILE MODEL
############################################################################################
@ -190,6 +192,7 @@ transforms:
enabled: false
compile_model:
stage: compile
expect_mem_change: true
run_per_gm: false
cuda_graph_batch_sizes: null
backend: torch-compile


@ -7,12 +7,14 @@ transforms:
build_and_load_factory_model:
stage: factory
run_per_gm: false
expect_mem_change: true
############################################################################################
# MOVE ARGUMENTS TO DEVICE
############################################################################################
move_inputs_to_device:
stage: weight_load
run_per_gm: false
expect_mem_change: true
############################################################################################
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
############################################################################################
@ -26,10 +28,11 @@ transforms:
initialize_cache:
stage: cache_init
run_per_gm: false
expect_mem_change: true
resize_kv_cache:
stage: cache_init
run_per_gm: false
free_mem_ratio: 0.0
expect_mem_change: true
############################################################################################
# COMPILE MODEL
############################################################################################


@ -9,16 +9,18 @@ and operates on a purely functional paradigm that is compatible with the torch c
"""
import math
from abc import ABC, abstractmethod
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union, final
import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator
from torch._ops import OpOverloadPacket
from torch.fx import Node
from torch.types import Number
from ...._utils import nvtx_range
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from ...._utils import nvtx_range, str_dtype_to_torch
from ..utils.logger import ad_logger
Constant = Union[int, float, str, None]
@ -281,44 +283,6 @@ class InputBuffer:
self._device_views = self._create_views(self._device_buffer)
class CacheConfig(BaseModel):
"""Cache configuration for attention-related dtypes."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.")
mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.")
delta_dtype: Optional[torch.dtype] = Field(
default=torch.float32, description="Delta cache dtype. Defaults to float32."
)
@field_validator("dtype", "mamba_dtype", "delta_dtype", mode="before")
@classmethod
def _coerce_dtype(cls, value):
if value is None or isinstance(value, torch.dtype):
return value
if isinstance(value, str):
dtype = getattr(torch, value, None)
assert isinstance(dtype, torch.dtype), f"Invalid {dtype=}"
return dtype
return value
def __or__(self, other: "CacheConfig") -> "CacheConfig":
"""Combine two CacheConfig objects field-wise using Python's `or` semantics.
For each field, selects the first non-None value between `self` and `other`.
"""
if not isinstance(other, CacheConfig):
raise NotImplementedError(f"Cannot combine CacheConfig with {type(other)}")
merged_kwargs = {}
for field_name in type(self).model_fields.keys():
merged_kwargs[field_name] = getattr(self, field_name) or getattr(other, field_name)
return CacheConfig(**merged_kwargs)
class SequenceInfo:
"""An interface to hold information about how the sequence is laid out and stored in cache.
@ -398,9 +362,9 @@ class SequenceInfo:
def __init__(
self,
max_seq_len: int = 1,
max_batch_size: int = 1,
page_size: int = 0,
max_seq_len: int,
max_batch_size: int,
tokens_per_block: Optional[int] = None,
max_num_tokens: Optional[int] = None,
vocab_size_padded: Optional[int] = None,
):
@ -411,9 +375,7 @@ class SequenceInfo:
includes the tokens in the input sequence and the tokens generated by the model.
max_batch_size: corresponds to the maximum number of sequences (or requests) that the
model can process.
page_size: corresponds to the page size of the cache. For an unpaged cache, the page
size should be set to max_seq_len. Also note that two sequences in a batch can not
share a page.
tokens_per_block: corresponds to the tokens per block of the cache.
max_num_tokens: corresponds to the maximum number of tokens that the model can process
across all sequences in the batch. If a batch is composed of context-only requests
of input sequence length ISL, then the maximum number of sequences possible in the
@ -427,46 +389,27 @@ class SequenceInfo:
# set up basic attributes
self.max_seq_len = max_seq_len
self.max_batch_size = max_batch_size
self.page_size = page_size if page_size > 0 else max_seq_len
self.vocab_size_padded = vocab_size_padded
# NOTE (lucaslie): WAR to address issue when using flashinfer attention with
self.tokens_per_block = tokens_per_block or max_seq_len
# NOTE (lucaslie): +1 is a WAR to address issue when using flashinfer attention with
# (max_batch_size, max_seq_len) input in trtllm runtime.
# see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
max_seq_len_adjusted = self.max_seq_len + 1
self.max_num_tokens = max_num_tokens or (max_seq_len + 1) * max_batch_size
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9883 clean up this hack
self.max_state_slots = max_batch_size + 1
# TODO (lucaslie): can we remove this eventually from this i/f?
self.vocab_size_padded = vocab_size_padded
# if the provided max_num_tokens is less than the max_batch_size * max_seq_len_adjusted,
# we use the provided max_num_tokens. If max_num_tokens provided is more, we still use
# max_batch_size * max_seq_len_adjusted since the extra tokens cannot be used.
self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted
if max_num_tokens is not None and max_num_tokens > 0:
self.max_num_tokens = min(self.max_num_tokens, max_num_tokens)
# Num pages can not be less than max_batch_size.
self._num_pages = max(
self.max_batch_size,
(self.max_num_tokens) // self.page_size # floored number of pages
+ (self.max_num_tokens / self.max_batch_size % self.page_size > 0) # check for overflow
* self.max_batch_size, # +1 page per sequence if overflow is required
)
# sanity check
assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size"
# cache_loc requires some special treatment due to block reuse. Note that the constraint for
# cache_loc with block_reuse is as follows:
# 0 <= cache_loc < num_pages
# len(cache_loc) <= max_num_cache_loc_assignments
max_num_cache_loc_assignments = (
max_seq_len_adjusted // self.page_size + 1
) * self.max_batch_size
# NOTE: we keep an extra state slot around to simplify cuda graph padding
# WHY?
# Requests that just finished won't free their used resources immediately. Specifically, the
# running order is self.scheduler.schedule_request, self._forward_step() and
# self._process_previous_batch() in the PyExecutor. Hence, the current forward step will
# remove finished requests but will not remove mamba_cache immediately and therefore it
# won't be available in time for padding in the next forward step.
self.max_num_state_slots = max_batch_size + 1
# log parameters
ad_logger.info(
f"[SequenceInfo:] {self.max_seq_len=}, {self.max_batch_size=}, {self.page_size=}, "
f"{self.max_num_tokens=} (inferred), {max_num_tokens=} (provided), {self.num_pages=}, "
f"{max_num_cache_loc_assignments=}"
f"[SequenceInfo:] {self.max_seq_len=}, {self.max_batch_size=}, {self.max_num_tokens=}"
)
# indicator if extra args are activated that are needed for cached attention backends
@ -496,8 +439,10 @@ class SequenceInfo:
# OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
("_gather_idx", self.max_num_tokens, torch.int),
("_mask_scatter_indices", self.max_num_tokens, torch.int),
# cache_loc is LAST for truncation optimization (it's the largest tensor)
("cache_loc", max_num_cache_loc_assignments, torch.int),
# cache_loc is LAST for truncation optimization (it can be the largest tensor)
# NOTE: sufficient for max_num_tokens forward pass. will be resized when KVCacheManager
# is created.
("cache_loc", self.max_num_tokens, torch.int),
]
# Create the InputBuffer that manages contiguous host and device memory
@ -619,35 +564,44 @@ class SequenceInfo:
def is_generate(self) -> bool:
return all(sl == 1 for sl in self.seq_len)
@property
def num_pages(self) -> int:
return self._num_pages
@num_pages.setter
def num_pages(self, value):
self._num_pages = value
# Check if we need to resize cache_loc (it's the last tensor in the buffer)
cache_loc_capacity = self._input_buffer.get_capacity("cache_loc")
if value > cache_loc_capacity:
ad_logger.info(
f"Resizing cache_loc capacity from {cache_loc_capacity} to {value} "
f"to accommodate num_pages={value}"
)
# Resize the input buffer (cache_loc is the last tensor, so this is supported)
self._input_buffer.resize("cache_loc", value)
# Also resize the args_list to match
old_size = len(self._args_list["cache_loc"])
self._args_list["cache_loc"].extend([0] * (value - old_size))
@property
def is_paged(self) -> bool:
return self.page_size < self.max_seq_len
@property
def page_assignments(self) -> List[List[int]]:
"""Return the page assignments for each sequence."""
return self._get_page_assignments(self.cache_loc, self.pages_per_seq)
def estimate_cache_tokens_per_forward(self) -> int:
"""Estimate the max number of tokens that will be cached for a forward pass.
It is estimated assuming a worst-case allocation of tokens across sequences in a batch.
"""
seq_len = math.ceil(self.max_num_tokens / self.max_batch_size)
num_blocks_estimate_per_seq = math.ceil(seq_len / self.tokens_per_block)
num_blocks_estimate = num_blocks_estimate_per_seq * self.max_batch_size
return num_blocks_estimate * self.tokens_per_block
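# Illustrative arithmetic (assumed example values, not from this change): with max_num_tokens=8192,
# max_batch_size=32, and tokens_per_block=64, the worst case is seq_len = ceil(8192 / 32) = 256,
# i.e. ceil(256 / 64) = 4 blocks per sequence, so the estimate is 4 * 32 * 64 = 8192 cached tokens.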
def estimate_cache_loc_capacity(self, num_blocks: int) -> None:
"""Estimate needed capacity of cache_loc based on available blocks and resize."""
cache_loc_capacity = self._input_buffer.get_capacity("cache_loc")
# cache_loc requires some special treatment due to block reuse. Note that the constraint for
# cache_loc with block_reuse is as follows:
# 0 <= cache_loc < num_blocks
# len(cache_loc) <= num_blocks * max_batch_size
# NOTE: however, num_blocks * max_batch_size can be a very large overestimation, so we
# assume at most 10x reuse of blocks.
estimated_capacity = num_blocks * min(10, self.max_batch_size)
# NOTE (lucaslie): WAR to address issue when using flashinfer attention with
# (max_batch_size, max_seq_len) input in trtllm runtime.
# see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
estimated_capacity = estimated_capacity + 1
if estimated_capacity > cache_loc_capacity:
self._input_buffer.resize("cache_loc", estimated_capacity)
# Also resize the args_list to match
old_size = len(self._args_list["cache_loc"])
self._args_list["cache_loc"].extend([0] * (estimated_capacity - old_size))
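# Illustrative sizing (assumed example values): with num_blocks=1000 and max_batch_size=32, the
# strict block-reuse bound would allow 1000 * 32 = 32000 cache_loc entries; capping assumed reuse
# at 10x (plus the +1 WAR above) yields an estimated capacity of 1000 * 10 + 1 = 10001, which
# cache_loc is resized to if its current capacity is smaller.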
@staticmethod
def _get_page_assignments(
cache_locations: List[int], pages_per_sequence: List[int]
@ -732,7 +686,8 @@ class SequenceInfo:
# figure out page assignments
pages_per_seq = [
len(ids_one_seq) // self.page_size + (len(ids_one_seq) % self.page_size > 0)
len(ids_one_seq) // self.tokens_per_block
+ (len(ids_one_seq) % self.tokens_per_block > 0)
for ids_one_seq in input_ids
]
cache_loc = list(range(sum(pages_per_seq)))
@ -889,7 +844,7 @@ class SequenceInfo:
This i/f will ensure that all sequence info args are updated accordingly. Reset values are
chosen as "neutral" values so that for cases like rounding up batch sizes for cudagraph we
only write to unused buffers/caches.
only write to unused caches.
"""
### UPDATE SEQUENCE LENGTH AND INPUT POSITION FIRST SINCE IT'S USED FOR OTHER UPDATES ######
if seq_len is None:
@ -953,7 +908,7 @@ class SequenceInfo:
# update last page length
if last_page_len is None:
last_page_len = [(slwc - 1) % self.page_size + 1 for slwc in seq_len_with_cache]
last_page_len = [(slwc - 1) % self.tokens_per_block + 1 for slwc in seq_len_with_cache]
self._store_arg("last_page_len", last_page_len)
# check for updated slot_idx
@ -1049,6 +1004,90 @@ class SequenceInfo:
host_function(**{arg: self._get_arg(arg) for arg in args})
class ResourceHandler(ABC):
"""An abstract interface to handle a generic resource needed by attention operators.
The ResourceHandler interface standardizes the operations that the cached sequence interface
performs on resources by providing an abstract handle.
"""
@abstractmethod
def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
"""Initialize the resource for the given sequence info."""
class ManagedResourceHandler(ResourceHandler):
"""An abstract interface to handle a resource that is managed by the cache manager."""
@final
def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
"""Allocate the resource for the given sequence info."""
raise NotImplementedError("Managed resources should not be allocated directly!")
class PagedResourceHandler(ManagedResourceHandler):
"""An abstract interface to handle a paged resource.
The PagedResourceHandler can be used to handle resources that support paging such as kv-caches.
"""
def __init__(self, *token_shape: int, dtype: torch.dtype) -> None:
"""Initialize the PagedResourceHandler.
Args:
token_shape: The shape of the resource per token.
dtype: The dtype of the resource.
"""
self.token_shape = token_shape
self.dtype = dtype
class StateResourceHandler(ManagedResourceHandler):
"""Handler for per-sequence state resources (e.g., Mamba SSM/conv states).
These resources have shape [max_batch_size, *state_shape] and are
managed by MambaHybridCacheManager via byte-level pooling.
"""
def __init__(self, *state_shape: int, dtype: torch.dtype) -> None:
"""Initialize the StateResourceHandler.
Args:
state_shape: The shape of a single state resource.
dtype: The dtype of the state resource.
"""
self.state_shape = state_shape
self.dtype = dtype
class UnpagedResourceHandler(ResourceHandler):
"""Handler for per-token unpaged resources (e.g., unpaged KV caches).
These resources have shape [max_batch_size, max_seq_len, *token_shape].
They are allocated locally and not managed by MambaHybridCacheManager.
"""
def __init__(self, *token_shape: int, dtype: torch.dtype) -> None:
"""Initialize the UnpagedResourceHandler.
Args:
token_shape: The shape of the resource per token.
dtype: The dtype of the resource.
"""
self.token_shape = token_shape
self.dtype = dtype
def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
"""Initialize the unpaged resource for the given sequence info."""
return torch.empty(
sequence_info.max_num_state_slots,
sequence_info.max_seq_len,
*self.token_shape,
device=sequence_info.device,
dtype=self.dtype,
)
class MHACallable(Protocol):
def __call__(
self,
@ -1062,18 +1101,10 @@ class PrepareMetadataCallable(Protocol):
) -> List[torch.Tensor]: ...
class GetCacheCallable(Protocol):
def __call__(self, sequence_info: SequenceInfo) -> torch.Tensor: ...
class GetBufferCallable(GetCacheCallable):
pass
CacheInitializerDict = Dict[str, GetCacheCallable]
BufferInitializerDict = Dict[str, GetBufferCallable]
AttentionLayout = Literal["bsnd", "bnsd"]
ResourceHandlerDict = Dict[str, ResourceHandler]
class AttentionDescriptor(ABC):
"""An interface to define a functional attention operator.
@ -1083,11 +1114,6 @@ class AttentionDescriptor(ABC):
specific to the attention op.
"""
@classmethod
@abstractmethod
def is_paged(cls) -> bool:
"""Return if the attention op is paged or not."""
@classmethod
@abstractmethod
def get_attention_layout(cls) -> AttentionLayout:
@ -1116,7 +1142,6 @@ class AttentionDescriptor(ABC):
*meta_std, # standard metadata fields identified by matching arg names!
*meta_extra,# metadata about the sequences as returned by the prepare_metadata op
*caches, # contains layer-specific caches per provided cache initializers
*buffers, # global buffers used by the attention op as provided by buffer initializers
*constants, # basic arguments (int, float, str, None) added as CONSTANTS in the graph
) -> torch.Tensor: ...
```
@ -1172,52 +1197,46 @@ class AttentionDescriptor(ABC):
@classmethod
@abstractmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
"""Provide a dictionary of function pointers that can be used to initialize the caches.
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
"""Provide a dictionary of resource handlers that can be used to initialize the resources.
The key corresponds to the argument name used in the attention op signature. The function
key doesn't need to be unique across multiple attention nodes in the graph. The key used to
describe the cache in the graph will be patched with the attention node index to ensure
uniqueness.
``get_cache_initializers`` will be called *once* during cache initialization and before
the initial forward pass for each attention op detected in the graph. The caches will be
managed by the global CacheManager and passed back to the attention op during the forward
pass.
The resource will be initialized before the initial forward pass and will be managed by the
global CacheManager and passed back to the model during the forward pass.
If the cache initializer requires information about the attention op, it can retrieve
the necessary information from the source attention node and cache config.
"""
@classmethod
def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
"""Provide a dictionary of function pointers that can be used to initialize buffers.
The key corresponds to the buffer name used in the graph module and will **not**
be patched unlike a cache key. Hence, it is a **global** key that is shared across all
attention ops in the model much like a regular buffer in an nn.Module. That means if this
i/f is called for multiple attention ops, the same buffer will be shared across all of them
if this function provides the same key multiple times.
Buffers are initialize *once* after the model initialization and before the initial forward
pass for each attention op detected in the graph. The buffer will be managed by the global
CacheManager and passed back to the attention op during the forward pass.
If the buffer initializer requires information about the attention op, it can retrieve
the necessary information from the source attention node.
"""
return {}
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
"""Provide a list of constant arguments to be passed to the attention op.
The constant arguments are passed to the attention op as additional arguments after the
caches and buffers. The constants are expected to be of type int, float, str, or None.
caches. The constants are expected to be of type int, float, str, or None.
"""
return []
@staticmethod
def resolve_cache_dtype(dtype_config: str, fallback_dtype: torch.dtype) -> torch.dtype:
"""Resolve cache dtype from KvCacheConfig dtype string to torch.dtype.
Args:
dtype_config: The dtype string from KvCacheConfig (e.g., "auto", "float16", "bfloat16").
fallback_dtype: The fallback dtype to use when dtype_config is "auto".
Returns:
The resolved torch.dtype.
"""
if dtype_config == "auto":
return fallback_dtype
return str_dtype_to_torch(dtype_config)
@classmethod
def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
"""Get function that performs host-side prep for the forward pass for the attention op.


@ -11,17 +11,16 @@ import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node
from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
BufferInitializerDict,
CacheConfig,
CacheInitializerDict,
Constant,
MHACallable,
SequenceInfo,
ResourceHandlerDict,
StateResourceHandler,
)
from .delta_rule.chunk import chunk_delta_rule_fwd
from .delta_rule.fused_recurrent import fused_recurrent_delta_rule_fwd
@ -136,11 +135,6 @@ def fla_cached_delta_rule_fake(
@AttentionRegistry.register("fla_delta")
class FlaDeltaBackend(AttentionDescriptor):
@classmethod
def is_paged(cls) -> bool:
# TODO: we should refine our notion of "is_paged" --> seems counterintuitive for ssm nows
return True
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
return "bsnd"
@ -164,8 +158,8 @@ class FlaDeltaBackend(AttentionDescriptor):
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
key_node = source_attn_node.args[1]
value_node = source_attn_node.args[2]
num_heads = key_node.meta["val"].shape[-2]
@ -173,21 +167,15 @@ class FlaDeltaBackend(AttentionDescriptor):
value_dim = value_node.meta["val"].shape[-1]
key_dtype = key_node.meta["val"].dtype
def _get_delta_cache(si: SequenceInfo):
return torch.empty(
si.max_state_slots,
return {
"delta_cache": StateResourceHandler(
num_heads,
key_dim,
value_dim,
device=si.device,
dtype=cache_config.delta_dtype or key_dtype,
# NOTE: not configurable at the moment, using auto to match the key dtype
dtype=cls.resolve_cache_dtype("auto", key_dtype),
)
return {"delta_cache": _get_delta_cache}
@classmethod
def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
return {}
}
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:


@ -1,5 +1,5 @@
from dataclasses import dataclass, fields
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Literal, Optional, Tuple, Union
import flashinfer
import torch
@ -7,6 +7,7 @@ from torch._ops import OpOverloadPacket
from torch._subclasses import FakeTensor
from torch.fx import Node
from ....llmapi.llm_args import KvCacheConfig
from ...flashinfer_utils import get_env_enable_pdl
from ..utils.cuda_graph import cuda_graph_state
from ..utils.logger import ad_logger
@ -15,14 +16,12 @@ from .attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
BufferInitializerDict,
CacheConfig,
CacheInitializerDict,
Constant,
MHACallable,
PagedResourceHandler,
PrepareMetadataCallable,
PrepareMetadataHostCallable,
SequenceInfo,
ResourceHandlerDict,
)
@ -57,6 +56,9 @@ class _FlashInferPlanner:
]
plan_params_prefill: Optional[PlanParams]
plan_params_decode: Optional[PlanParams]
kv_layout: Literal["NHD", "HND"] = (
"NHD" # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/10966
)
def __init__(self):
self.workspace_buffer = None
@ -77,7 +79,7 @@ class _FlashInferPlanner:
if use_cuda_graph:
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
self.kv_layout,
use_cuda_graph=True,
paged_kv_indptr_buffer=indptr,
paged_kv_indices_buffer=indices,
@ -88,28 +90,33 @@ class _FlashInferPlanner:
else:
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
self.kv_layout,
use_tensor_cores=True,
backend="fa2" if torch.cuda.get_device_capability(0) == (9, 0) else "auto",
)
def init_workspace(self, workspace_buffer: torch.Tensor):
def reset(self, device: torch.device) -> None:
self.plan_params_prefill = None
self.plan_params_decode = None
if isinstance(self.workspace_buffer, torch.Tensor):
return
self.__init__() # reset all state
self.workspace_buffer = workspace_buffer
# NOTE (lucaslie): avoid OOM for many cudagraphs,
# see https://github.com/NVIDIA/TensorRT-LLM/pull/3686
self.workspace_buffer = torch.empty(320 * 1024 * 1024, device=device, dtype=torch.uint8)
# NOTE (lucaslie): flashinfer fa3 backend has accuracy issue + illegal memory access issues
# on H100 PCIe, see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
self.kv_layout,
backend="fa2",
)
self.decode_wrapper = self._init_decode_wrapper()
def reset(self) -> None:
self.plan_params_prefill = None
self.plan_params_decode = None
def plan_generate_only(
self,
num_seq: int,
@ -248,7 +255,7 @@ def prepare_flashinfer_metadata(
num_seq = num_prefill + num_decode
num_tokens = num_prefill_tokens + num_decode
_GlobalFlashInferPlanner.reset()
_GlobalFlashInferPlanner.reset(position_ids.device)
qo_indptr = cu_seqlen[: num_seq + 1]
@ -316,8 +323,6 @@ def flashinfer_mha_with_cache(
# CACHES
k_cache: torch.Tensor,
v_cache: torch.Tensor,
# BUFFERS
workspace_buffer: torch.Tensor,
# CONSTANTS
scale: Optional[float],
k_scale: float,
@ -458,8 +463,6 @@ def flashinfer_mha_with_cache_fake(
# CACHES
k_cache: torch.Tensor,
v_cache: torch.Tensor,
# BUFFERS
workspace_buffer: torch.Tensor,
# CONSTANTS
scale: Optional[float],
k_scale: float,
@ -474,11 +477,6 @@ class FlashInferAttention(AttentionDescriptor):
def _get_planner(cls) -> _FlashInferPlanner:
return _GlobalFlashInferPlanner
@classmethod
def is_paged(cls):
"""Return if the attention op is paged or not."""
return True
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
"""Get the attention layout expected by the backend."""
@ -519,35 +517,25 @@ class FlashInferAttention(AttentionDescriptor):
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
# source op is [bsnd] layout already
k_fake: FakeTensor = source_attn_node.args[1].meta["val"]
num_kv_heads = k_fake.shape[2]
head_dim = k_fake.shape[3]
def _get_cache(si: SequenceInfo):
return torch.empty(
si.num_pages,
si.page_size,
return {
"k_cache": PagedResourceHandler(
num_kv_heads,
head_dim,
device=si.device,
dtype=cache_config.dtype or k_fake.dtype,
)
return {"k_cache": _get_cache, "v_cache": _get_cache}
@classmethod
def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
def _init_workspace(si: SequenceInfo) -> torch.Tensor:
# NOTE (lucaslie): avoid OOM for many cudagraphs,
# see https://github.com/NVIDIA/TensorRT-LLM/pull/3686
buffer = torch.empty(320 * 1024 * 1024, dtype=torch.uint8, device=si.device)
cls._get_planner().init_workspace(buffer)
return buffer
return {"workspace_buffer": _init_workspace}
dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype),
),
"v_cache": PagedResourceHandler(
num_kv_heads,
head_dim,
dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype),
),
}
@classmethod
def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:

View File

@ -33,16 +33,16 @@ from torch.fx import Node
from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID
from tensorrt_llm._torch.modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
CacheConfig,
CacheInitializerDict,
Constant,
MHACallable,
SequenceInfo,
ResourceHandlerDict,
StateResourceHandler,
)
@ -164,10 +164,6 @@ def cuda_cached_causal_conv1d_wrapper(input: torch.Tensor, *args, **kwargs) -> t
@AttentionRegistry.register("cuda_causal_conv")
class CudaBackendCausalConv(AttentionDescriptor):
@classmethod
def is_paged(cls) -> bool:
return True
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
# Hidden states follow [b, s, c]
@ -193,24 +189,21 @@ class CudaBackendCausalConv(AttentionDescriptor):
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
inp_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
w_fake: torch.Tensor = source_attn_node.args[1].meta["val"]
in_channels = inp_fake.shape[-1]
kernel_size = w_fake.shape[-1]
def _get_conv_cache(si: SequenceInfo):
return torch.empty(
si.max_state_slots,
in_channels,
max(1, kernel_size - 1),
device=si.device,
dtype=inp_fake.dtype,
)
return {"conv_state_cache": _get_conv_cache}
conv_state_handler = StateResourceHandler(
in_channels,
max(1, kernel_size - 1),
# NOTE: not configurable at the moment, using auto to match the input dtype
dtype=cls.resolve_cache_dtype("auto", inp_fake.dtype),
)
return {"conv_state_cache": conv_state_handler}
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:


@ -16,17 +16,16 @@ import torch.nn.functional as F
from torch._ops import OpOverloadPacket
from torch.fx import Node
from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
BufferInitializerDict,
CacheConfig,
CacheInitializerDict,
Constant,
MHACallable,
SequenceInfo,
ResourceHandlerDict,
StateResourceHandler,
)
@ -270,11 +269,6 @@ def _torch_cached_causal_conv1d_fake(
@AttentionRegistry.register("torch_causal_conv")
class TorchBackendCausalConv(AttentionDescriptor):
@classmethod
def is_paged(cls) -> bool:
# TODO: we should refine our notion of "is_paged" --> seems counterintuitive for ssm nows
return True
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
# Hidden states follow [b, s, c]
@ -300,28 +294,22 @@ class TorchBackendCausalConv(AttentionDescriptor):
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
inp_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
w_fake: torch.Tensor = source_attn_node.args[1].meta["val"]
in_channels = inp_fake.shape[-1]
kernel_size = w_fake.shape[-1]
def _get_conv_cache(si: SequenceInfo):
return torch.empty(
si.max_state_slots,
return {
"conv_state_cache": StateResourceHandler(
in_channels,
kernel_size,
device=si.device,
dtype=inp_fake.dtype,
# NOTE: not configurable at the moment, using auto to match the input dtype
dtype=cls.resolve_cache_dtype("auto", inp_fake.dtype),
)
return {"conv_state_cache": _get_conv_cache}
@classmethod
def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
return {}
}
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:


@ -12,17 +12,16 @@ import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node
from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
BufferInitializerDict,
CacheConfig,
CacheInitializerDict,
Constant,
MHACallable,
SequenceInfo,
ResourceHandlerDict,
StateResourceHandler,
)
from .torch_mamba import _torch_ssm_prefill
@ -268,11 +267,6 @@ def _torch_cached_ssm_fake(
@AttentionRegistry.register("torch_ssm")
class TorchBackendSSM(AttentionDescriptor):
@classmethod
def is_paged(cls) -> bool:
# TODO: we should refine our notion of "is_paged" --> seems counterintuitive for ssm now
return True
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
# Hidden states follow [b, s, n, d]
@ -297,8 +291,8 @@ class TorchBackendSSM(AttentionDescriptor):
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
# Shapes from fake tensors
hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]
@ -315,23 +309,13 @@ class TorchBackendSSM(AttentionDescriptor):
ssm_state_size = max(1, B_fake.shape[-1])
# extract ssm_state_dtype from cache_config or hs_fake
ssm_state_dtype = cache_config.mamba_dtype or hs_fake.dtype
ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype)
def _get_ssm_cache(si: SequenceInfo):
return torch.empty(
si.max_state_slots,
num_heads,
head_dim,
ssm_state_size,
device=si.device,
dtype=ssm_state_dtype,
return {
"ssm_state_cache": StateResourceHandler(
num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype
)
return {"ssm_state_cache": _get_ssm_cache}
@classmethod
def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
return {}
}
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:


@ -24,17 +24,17 @@ from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chun
from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update
from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined
from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
CacheConfig,
CacheInitializerDict,
Constant,
MHACallable,
PrepareMetadataCallable,
SequenceInfo,
ResourceHandlerDict,
StateResourceHandler,
)
@ -274,10 +274,6 @@ def _triton_cached_ssm_fake(
@AttentionRegistry.register("triton_ssm")
class TritonBackendSSM(AttentionDescriptor):
@classmethod
def is_paged(cls) -> bool:
return True
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
# Hidden states follow [b, s, n, d]
@ -313,8 +309,8 @@ class TritonBackendSSM(AttentionDescriptor):
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
# Shapes from fake tensors
hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]
@ -328,19 +324,13 @@ class TritonBackendSSM(AttentionDescriptor):
ssm_state_size = max(1, B_fake.shape[-1])
# extract ssm_state_dtype from cache_config or hs_fake
ssm_state_dtype = cache_config.mamba_dtype or hs_fake.dtype
ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype)
def _get_ssm_cache(si: SequenceInfo):
return torch.empty(
si.max_state_slots,
num_heads,
head_dim,
ssm_state_size,
device=si.device,
dtype=ssm_state_dtype,
return {
"ssm_state_cache": StateResourceHandler(
num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype
)
return {"ssm_state_cache": _get_ssm_cache}
}
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:


@ -6,21 +6,35 @@ import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node
from ....llmapi.llm_args import KvCacheConfig
from .attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
BufferInitializerDict,
CacheConfig,
CacheInitializerDict,
MHACallable,
SequenceInfo,
ResourceHandlerDict,
UnpagedResourceHandler,
)
from .triton_attention import _flattened_context_mha, _generate_mha
Constant = Union[int, float, str, None]
def _precompute_inv_freq(
max_seq_len: int, head_dim: int, rope_theta: float, device: torch.device
) -> torch.Tensor:
inv_freq = 1.0 / (
rope_theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)
)
t = torch.arange(max_seq_len, device=inv_freq.device, dtype=inv_freq.dtype)
freqs = torch.outer(t, inv_freq.to(t.device))
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
cos_sin_stacked = torch.stack([emb.cos().to(torch.bfloat16), emb.sin().to(torch.bfloat16)])
return cos_sin_stacked
@torch.library.custom_op(
"auto_deploy::triton_attention_fused_flattened_mla_with_cache", mutates_args=()
)
@ -41,10 +55,9 @@ def fused_flattened_mla_with_cache(
# CACHES
k_cache: torch.Tensor,
v_cache: torch.Tensor,
# BUFFERS
cos_sin_stacked: torch.Tensor,
# CONSTANTS
softmax_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
) -> torch.Tensor:
"""Flattened & fused MLA with cache with triton kernels."""
# b, s info
@ -84,7 +97,12 @@ def fused_flattened_mla_with_cache(
k_pe = k_pe.clone().transpose(1, 2).view(*bs_view, -1, qk_rope_head_dim).contiguous()
value_states = value_states.transpose(1, 2).view(*bs_view, -1, v_head_dim).contiguous()
# Apply RoPE
if cos_sin_stacked.numel() > 0:
if rope_theta is not None:
max_seq_len = (input_pos + seq_len).max().item()
cos_sin_stacked = _precompute_inv_freq(
max_seq_len, qk_rope_head_dim, rope_theta, q_pe.device
)
# Extract cos and sin from freqs_cis
cos_base = cos_sin_stacked[0, ...]
sin_base = cos_sin_stacked[1, ...]
@ -176,10 +194,9 @@ def fused_flattened_mla_with_cache_fake(
# CACHES
k_cache: torch.Tensor,
v_cache: torch.Tensor,
# BUFFERS
cos_sin_stacked: torch.Tensor,
# CONSTANTS
softmax_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
v_head_dim = kv.shape[-1] - q_nope.shape[-1]
return torch.empty_like(kv[..., -v_head_dim:])
@ -187,11 +204,6 @@ def fused_flattened_mla_with_cache_fake(
@AttentionRegistry.register("MultiHeadLatentAttention")
class MultiHeadLatentAttention(AttentionDescriptor):
@classmethod
def is_paged(cls) -> bool:
"""Return if the attention op is paged or not."""
return False
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
"""Get the attention layout expected by the backend."""
@ -216,8 +228,8 @@ class MultiHeadLatentAttention(AttentionDescriptor):
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
q_nope_fake = source_attn_node.args[0].meta["val"]
q_pe_fake = source_attn_node.args[1].meta["val"]
kv_fake = source_attn_node.args[2].meta["val"]
@ -226,56 +238,21 @@ class MultiHeadLatentAttention(AttentionDescriptor):
head_dim = q_nope_fake.shape[-1]
rope_dim = q_pe_fake.shape[-1]
def _get_k_cache(si: SequenceInfo):
assert not si.is_paged, "Paged cache not supported for MultiHeadLatentAttention"
return torch.empty(
si.num_pages,
si.page_size,
return {
"k_cache": UnpagedResourceHandler(
num_kv_heads,
head_dim + rope_dim,
device=si.device,
dtype=cache_config.dtype or kv_fake.dtype,
)
def _get_v_cache(si: SequenceInfo):
assert not si.is_paged, "Paged cache not supported for MultiHeadLatentAttention"
return torch.empty(
si.num_pages,
si.page_size,
dtype=cls.resolve_cache_dtype(cache_config.dtype, kv_fake.dtype),
),
"v_cache": UnpagedResourceHandler(
num_kv_heads,
head_dim,
device=si.device,
dtype=cache_config.dtype or kv_fake.dtype,
)
return {"k_cache": _get_k_cache, "v_cache": _get_v_cache}
@classmethod
def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
q_pe_fake = source_attn_node.args[1].meta["val"]
rope_head_dim = q_pe_fake.shape[-1]
rope_theta: float = 10000.0 # TODO: remove once MLA is unfused
def _get_cos_sin_stacked(si: SequenceInfo):
if rope_theta is None:
return torch.empty(0, device=si.device)
return cls._precompute_inv_freq(si.max_seq_len, rope_head_dim, rope_theta).to(si.device)
return {
f"cos_sin_stacked_{rope_head_dim}_{rope_theta}".replace(".", "_"): _get_cos_sin_stacked
dtype=cls.resolve_cache_dtype(cache_config.dtype, kv_fake.dtype),
),
}
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
return [None]
@staticmethod
def _precompute_inv_freq(seq_len: int, head_dim: int, rope_theta: float = 1e4) -> torch.Tensor:
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq_len, device=inv_freq.device, dtype=inv_freq.dtype)
freqs = torch.outer(t, inv_freq.to(t.device))
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
cos_sin_stacked = torch.stack([emb.cos().to(torch.bfloat16), emb.sin().to(torch.bfloat16)])
return cos_sin_stacked
softmax_scale = None
rope_theta = 10000.0 # TODO: remove once MLA is unfused
return [softmax_scale, rope_theta]


@ -8,18 +8,17 @@ from torch._ops import OpOverloadPacket
from torch._subclasses import FakeTensor
from torch.fx import Node
from ....llmapi.llm_args import KvCacheConfig
from ..utils.logger import ad_logger
from ..utils.node_utils import extract_op_args
from .attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
BufferInitializerDict,
CacheConfig,
CacheInitializerDict,
Constant,
MHACallable,
SequenceInfo,
ResourceHandlerDict,
UnpagedResourceHandler,
)
from .torch_attention import repeat_kv, update_kv_cache
@ -375,11 +374,6 @@ def torch_backend_mha_with_cache_fake(
@AttentionRegistry.register("torch")
class TorchBackendAttention(AttentionDescriptor):
@classmethod
def is_paged(cls) -> bool:
"""Return if the attention op is paged or not."""
return False
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
"""Get the attention layout expected by the source op and the cached attention op."""
@ -404,8 +398,8 @@ class TorchBackendAttention(AttentionDescriptor):
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
# source op is [bsnd] layout already
k_fake: FakeTensor = source_attn_node.args[1].meta["val"]
v_fake: FakeTensor = source_attn_node.args[2].meta["val"]
@ -413,33 +407,18 @@ class TorchBackendAttention(AttentionDescriptor):
k_head_dim = k_fake.shape[3]
v_head_dim = v_fake.shape[3]
def _get_k_cache(si: SequenceInfo):
assert not si.is_paged, "Paged cache not supported for torch backend"
return torch.empty(
si.num_pages,
si.page_size,
return {
"k_cache": UnpagedResourceHandler(
num_kv_heads,
k_head_dim,
device=si.device,
dtype=cache_config.dtype or k_fake.dtype,
)
def _get_v_cache(si: SequenceInfo):
assert not si.is_paged, "Paged cache not supported for torch backend"
return torch.empty(
si.num_pages,
si.page_size,
dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype),
),
"v_cache": UnpagedResourceHandler(
num_kv_heads,
v_head_dim,
device=si.device,
dtype=cache_config.dtype or v_fake.dtype,
)
return {"k_cache": _get_k_cache, "v_cache": _get_v_cache}
@classmethod
def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
return {}
dtype=cls.resolve_cache_dtype(cache_config.dtype, v_fake.dtype),
),
}
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:


@ -172,7 +172,7 @@ def torch_fake_quant_fp8_linear(
- input_zp / weight_zp ignored
"""
if weight_quantized.dtype != torch.float8_e4m3fn:
raise TypeError("FP8 path requires weight_quantized.dtype == float8_e4m3fn")
raise TypeError("FP8 path requires weight_quantized.dtype == torch.float8_e4m3fn")
s_in = _expect_single_scale(input_scale, "input_scale")
s_w = _expect_single_scale(weight_scale, "weight_scale")


@ -9,18 +9,17 @@ from torch._ops import OpOverloadPacket
from torch._subclasses import FakeTensor
from torch.fx import Node
from ....llmapi.llm_args import KvCacheConfig
from ..utils.logger import ad_logger
from ..utils.node_utils import extract_op_args
from .attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
BufferInitializerDict,
CacheConfig,
CacheInitializerDict,
Constant,
MHACallable,
SequenceInfo,
ResourceHandlerDict,
UnpagedResourceHandler,
)
from .triton_kernels.attention_with_kv_cache import (
attention_kv_stage2,
@ -312,11 +311,6 @@ def flattened_mha_fake(
@AttentionRegistry.register("triton")
class TritonAttention(AttentionDescriptor):
@classmethod
def is_paged(cls) -> bool:
"""Return if the attention op is paged or not."""
return False
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
"""Get the attention layout expected by the source op and the cached attention op."""
@ -341,8 +335,8 @@ class TritonAttention(AttentionDescriptor):
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
# source op is [bsnd] layout already
k_fake: FakeTensor = source_attn_node.args[1].meta["val"]
v_fake: FakeTensor = source_attn_node.args[2].meta["val"]
@ -350,33 +344,18 @@ class TritonAttention(AttentionDescriptor):
k_head_dim = k_fake.shape[3]
v_head_dim = v_fake.shape[3]
def _get_k_cache(si: SequenceInfo):
assert not si.is_paged, "Paged cache not supported for triton"
return torch.empty(
si.num_pages,
si.page_size,
return {
"k_cache": UnpagedResourceHandler(
num_kv_heads,
k_head_dim,
device=si.device,
dtype=cache_config.dtype or k_fake.dtype,
)
def _get_v_cache(si: SequenceInfo):
assert not si.is_paged, "Paged cache not supported for triton"
return torch.empty(
si.num_pages,
si.page_size,
dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype),
),
"v_cache": UnpagedResourceHandler(
num_kv_heads,
v_head_dim,
device=si.device,
dtype=cache_config.dtype or v_fake.dtype,
)
return {"k_cache": _get_k_cache, "v_cache": _get_v_cache}
@classmethod
def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
return {}
dtype=cls.resolve_cache_dtype(cache_config.dtype, v_fake.dtype),
),
}
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:

View File

@ -9,7 +9,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from ...llmapi.llm_args import (
BuildConfig,
EagleDecodingConfig,
KvCacheConfig,
SamplerType,
TorchLlmArgs,
_ParallelConfig,
@ -45,7 +44,6 @@ def _check_for_default_value_only(
_TRANSFORMS_SHORTCUT_LOOKUP = {
"attn_backend": ("insert_cached_attention.backend", "transformers_replace_cached_attn.backend"),
"free_mem_ratio": ("resize_kv_cache.free_mem_ratio",),
"compile_backend": ("compile_model.backend",),
"cuda_graph_batch_sizes": ("compile_model.cuda_graph_batch_sizes",),
}
@ -191,25 +189,11 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
device: str = Field(default="cuda", description="The device to use for the model.", frozen=True)
# TODO: see if we can just remove this field and use kv_cache_config.dtype instead?
kv_cache_dtype: str = Field(
default="auto",
description="Data type for KV cache. This is a temporary field until kv_cache_dtype is "
"supported in AutoDeploy.",
)
sampler_type: Union[str, SamplerType] = Field(
default=SamplerType.TorchSampler,
description="The type of sampler to use. Options are TRTLLMSampler or TorchSampler. Defaults to TorchSampler.",
)
# NOTE: we do not support copy_on_partial_reuse in AutoDeploy yet
# see https://github.com/NVIDIA/TensorRT-LLM/issues/7142
kv_cache_config: KvCacheConfig = Field(
default_factory=lambda **kwargs: KvCacheConfig(copy_on_partial_reuse=False, **kwargs),
description="KV cache config.",
)
max_beam_width: int = Field(
default=1,
description="The maximum beam width. >1 is not supported by AutoDeploy.",
@ -240,12 +224,6 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
default="flashinfer",
description=_shortcut_description("Attention backend to use.", "attn_backend"),
)
free_mem_ratio: float = Field(
default=0.0,
description=_shortcut_description(
"The fraction of available memory to allocate for cache.", "free_mem_ratio"
),
)
compile_backend: str = Field(
default="torch-compile",
description=_shortcut_description(
@ -263,16 +241,8 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
)
### SEQUENCE INTERFACE CONFIG ##################################################################
max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.")
max_seq_len: int = Field(default=512, ge=1, description="The maximum sequence length.")
max_batch_size: int = Field(default=8, ge=1, description="The maximum batch size.")
attn_page_size: int = Field(
default=64,
ge=1,
description="Page size for attention (tokens_per_block). For triton and torch "
"backends, this should equal max_seq_len. Temporary field until tokens_per_block gets "
"properly passed through.",
)
def model_dump(self, *args, **kwargs):
"""Convert the arguments to a dictionary that can be used as kwargs for the LLM API."""
@ -292,23 +262,6 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
return kwargs
### VALIDATION #################################################################################
@model_validator(mode="after")
# TODO: discuss what to do with this once we fully transition to the new inference optimizer
def update_attn_page_size(self):
# NOTE force attn_page_size to equal max_seq_len for triton backend
if self.transforms.get("insert_cached_attention", {}).get("backend") in [
"triton",
"torch",
]:
self.attn_page_size = self.max_seq_len
# NOTE: (hg) For transformers mode. This is ugly.
if self.transforms.get("transformers_replace_cached_attn", {}).get("backend") in [
"triton",
"torch",
]:
self.attn_page_size = self.max_seq_len
return self
@field_validator("model_factory", mode="after")
@classmethod
def model_factory_exists(cls, value: str) -> str:
@ -358,16 +311,6 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
self.update_transforms_with_shortcuts()
return self
@field_validator("kv_cache_config", mode="after")
@classmethod
def validate_kv_cache_config(cls, kv_cache_config: KvCacheConfig) -> KvCacheConfig:
if kv_cache_config.copy_on_partial_reuse:
kv_cache_config.copy_on_partial_reuse = False
ad_logger.warning(
"copy_on_partial_reuse is not supported by AutoDeploy. Setting it to False."
)
return kv_cache_config
### UTILITY METHODS ############################################################################
def create_factory(self) -> ModelFactory:
"""Create a model factory from the arguments."""

View File

@ -27,10 +27,6 @@ from torch._prims_common import DeviceLikeType
from torch.export import Dim
from torch.fx import GraphModule
from ..custom_ops.attention_interface import CacheConfig
from ..utils.cuda_mem_tracker import get_mem_info_in_mb
from ..utils.logger import ad_logger
DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension
@ -211,13 +207,15 @@ class ModelFactory(ABC):
"""Returns the sharding config for this model."""
return self._sharding_config
def get_cache_config(self) -> CacheConfig:
"""Return the cache configuration for the model.
def get_cache_config_updates(self) -> Dict[str, Any]:
"""Return updates for the KVCacheConfig for the model.
Returns:
The cache configuration for the model.
A dictionary of updates for the KVCacheConfig for the model.
Check tensorrt_llm/llmapi/llm_args.py for the KVCacheConfig fields.
"""
return CacheConfig()
return {}
def init_tokenizer(self) -> Optional[Any]:
"""Initialize the tokenizer for the model.
@ -301,22 +299,12 @@ class ModelFactory(ABC):
<SIZE_OF_LARGEST_CHECKPOINT_FILE>
"""
ad_logger.info("Loading and initializing weights.")
free_mem_pre, _ = get_mem_info_in_mb()
ad_logger.info(f"Free memory before loading weights (MB): {free_mem_pre}")
self._to_maybe_random(model, device)
params_size = sum(p.numel() * p.element_size() for p in model.parameters())
total_size_GB = params_size / (1024**3)
ad_logger.info(f"Estimated parameters memory: {total_size_GB:.2f} GB")
if not self.skip_loading_weights:
self.prefetch_checkpoint(force=True)
self._load_checkpoint(model, device)
ad_logger.info("Loading and initializing weights. Done.")
free_mem_post, _ = get_mem_info_in_mb()
ad_logger.info(f"Free memory after loading weights (MB): {free_mem_post}")
@staticmethod
def _to_maybe_random(model: nn.Module, device: DeviceLikeType):
"""A mix of ``model.to(device)`` and random initialization of parameters.

View File

@ -33,7 +33,6 @@ from transformers.utils import (
WEIGHTS_NAME,
)
from ..custom_ops.attention_interface import CacheConfig
from ..utils._config import deep_merge_dicts
from ..utils.logger import ad_logger
from .factory import (
@ -261,18 +260,16 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
return self._quant_config_reader.get_config()
return {}
def get_cache_config(self):
"""Return kv cache dtype configuration."""
def get_cache_config_updates(self):
"""Return kv cache dtype updates."""
if not self._quant_config_reader:
return CacheConfig(dtype=None)
return {}
kv_cache_dtype = self._quant_config_reader.get_config().get("kv_cache_dtype")
torch_dtype = torch.float8_e4m3fn if kv_cache_dtype == "float8_e4m3fn" else None
assert torch_dtype in (torch.float8_e4m3fn, None), (
f"Unsupported dtype: {torch_dtype}. Only torch.float8_e4m3fn is supported."
kv_cache_dtype = self._quant_config_reader.get_config().get("kv_cache_dtype", "auto")
assert kv_cache_dtype in ("fp8", "auto"), (
f"Unsupported dtype: {kv_cache_dtype}. Only fp8 and auto are supported."
)
return CacheConfig(dtype=torch_dtype)
return {"dtype": kv_cache_dtype}
def init_tokenizer(self) -> Optional[Any]:
"""Initialize the tokenizer—either a custom name or the model's default."""

View File

@ -106,7 +106,7 @@ class ModelOPTQuantConfigReader(QuantConfigReader):
if kv_algo:
if kv_algo != "FP8":
raise ValueError(f"KV cache quantization format {kv_algo} not supported.")
quant_config["kv_cache_dtype"] = "float8_e4m3fn"
quant_config["kv_cache_dtype"] = "fp8"
self._quant_config = quant_config

View File

@ -41,6 +41,7 @@ from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.llmapi.llm_args import (
ContextChunkingPolicy,
EagleDecodingConfig,
KvCacheConfig,
LoadFormat,
SamplerType,
TorchLlmArgs,
@ -48,9 +49,9 @@ from tensorrt_llm.llmapi.llm_args import (
from tensorrt_llm.llmapi.tokenizer import TokenizerBase
from ...._utils import get_free_port, mpi_rank, mpi_world_size
from ....bindings.internal.batch_manager import CacheType
from ....mapping import Mapping
from ...distributed import Distributed
from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager
from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine
from ...pyexecutor.py_executor import PyExecutor
from ...pyexecutor.resource_manager import (
@ -67,7 +68,6 @@ from ...pyexecutor.scheduler import (
ScheduledRequests,
SimpleScheduler,
)
from ..custom_ops.attention_interface import SequenceInfo
from ..distributed.common import initialize_or_skip
from ..llm_args import LlmArgs
from ..transform.optimizer import InferenceOptimizer
@ -82,44 +82,6 @@ class ReportingInfo:
enable_iter_req_stats: bool = False
class _CacheManagerWithFakePool(KVCacheManager):
"""We use the default KVCacheManager but with a fake pool by setting head_dim=0.
The actual cache pools are managed by auto_deploy layerwise cache pools.
"""
def __init__(
self,
kv_cache_config,
num_blocks: int,
tokens_per_block: int,
max_seq_len: int,
max_batch_size: int,
):
self.num_blocks = num_blocks
super().__init__(
kv_cache_config=kv_cache_config,
kv_cache_type=CacheType.SELF,
num_layers=1,
num_kv_heads=1,
head_dim=0,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
mapping=Mapping(),
)
def calculate_max_num_blocks(
self, kv_cache_config, head_dim, tokens_per_block, mapping, dtype, kv_factor
) -> Tuple[int, int]:
"""Calculate the maximum number of blocks needed for the cache."""
# TODO: this is VERY hacky... Ideally, we want to compute the number of blocks
# just like in the original implementation. However, let's wait for the layer-wise attention
# implementation before over-optimizing the function here
ad_logger.info("Using fake cache manager with head_dim=0 and num pages:", self.num_blocks)
return self.num_blocks, 0
class ADHiddenStateManager(Eagle3ResourceManager):
def __init__(
self,
@ -275,6 +237,7 @@ def construct_draft_llm_args(
def create_draft_kv_cache_manager_maybe(
draft_model_engine: Optional[PyTorchModelEngine],
ad_config: LlmArgs,
kv_cache_config_tuned: KvCacheConfig,
dist_mapping: Mapping,
) -> Optional[KVCacheManager]:
if draft_model_engine is None or not draft_model_engine.model.model_config.is_generation:
@ -287,8 +250,8 @@ def create_draft_kv_cache_manager_maybe(
model_engine=draft_model_engine,
kv_cache_manager_cls=kv_cache_manager_cls,
mapping=dist_mapping,
kv_cache_config=ad_config.kv_cache_config,
tokens_per_block=ad_config.attn_page_size,
kv_cache_config=kv_cache_config_tuned,
tokens_per_block=kv_cache_config_tuned.tokens_per_block,
max_seq_len=ad_config.max_seq_len,
max_batch_size=ad_config.max_batch_size,
spec_config=ad_config.speculative_config,
@ -321,19 +284,33 @@ def _generate_dummy_request(
ResourceManagerType.SPEC_RESOURCE_MANAGER
)
# check if we have a free slot available and free page available
if not slot_manager.slot_manager.free_slots or kv_cache_manager.get_num_free_blocks() == 0:
# check if it's a hybrid kv-cache manager
is_hybrid_cache = isinstance(kv_cache_manager, MambaHybridCacheManager)
# check if we have a free page and free state available
if not kv_cache_manager.get_num_free_blocks():
return None
if is_hybrid_cache and not kv_cache_manager.mamba_cache_free_blocks:
return None
# generate a dummy request
dummy_request = kv_cache_manager.add_dummy_requests([request_id], **request_kwargs)[0]
dummy_request.is_cuda_graph_dummy = True
# generate a dummy scheduled requests object
dummy_scheduled_requests = ScheduledRequests()
dummy_scheduled_requests.generation_requests.append(dummy_request)
# if it's a hybrid kv-cache manager, we need to manually call prepare_resources again (not done
# in add_dummy_requests)
if is_hybrid_cache:
kv_cache_manager.prepare_resources(dummy_scheduled_requests)
# add to spec resource manager
if spec_res_mgr:
spec_res_mgr.add_dummy_requests([request_id])
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9883 clean up this hack
# NOTE: hack to avoid blocking a slot for the dummy request
dummy_request.seq_slot = slot_manager.get_max_resource_count()
dummy_request.py_seq_slot = dummy_request.seq_slot
@ -448,11 +425,6 @@ class ADEngine(ModelEngine):
):
"""Build the ADEngine using the LlmArgs that gets passed through from the LLM."""
max_batch_size = ad_config.max_batch_size
max_seq_len = ad_config.max_seq_len
attn_page_size = ad_config.attn_page_size
max_num_tokens = ad_config.max_num_tokens
# update device to contain the current default device if it's in cuda
device = torch.device(ad_config.device)
if device.type == "cuda" and device.index is None:
@ -461,14 +433,17 @@ class ADEngine(ModelEngine):
factory = ad_config.create_factory()
# initialize seq info object
seq_info = SequenceInfo(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
page_size=attn_page_size,
max_num_tokens=max_num_tokens,
# Initialize CachedSequenceInterface - it will create SequenceInfo internally
# using tokens_per_block from kv_cache_config
cache_seq_interface = CachedSequenceInterface(
max_seq_len=ad_config.max_seq_len,
max_batch_size=ad_config.max_batch_size,
device=device,
kv_cache_config=ad_config.kv_cache_config,
max_num_tokens=ad_config.max_num_tokens,
vocab_size_padded=factory.vocab_size_padded,
)
reporting_info = ReportingInfo(
print_log=False,
enable_iter_perf_stats=ad_config.enable_iter_perf_stats,
@ -483,8 +458,7 @@ class ADEngine(ModelEngine):
# construct engine
return cls(
build_and_optimize,
seq_info,
device,
cache_seq_interface,
ad_config=ad_config,
mapping=mapping,
dist=dist,
@ -495,14 +469,21 @@ class ADEngine(ModelEngine):
def __init__(
self,
get_inference_model: GetInferenceModel,
seq_info: SequenceInfo,
device: DeviceLikeType,
cache_seq_interface: CachedSequenceInterface,
ad_config: Optional[LlmArgs] = None,
mapping: Optional[Mapping] = None,
dist: Optional[Distributed] = None,
reporting_info: ReportingInfo = ReportingInfo(),
) -> None:
"""Initialize the engine with model and sequence information."""
"""Initialize the engine with model and CachedSequenceInterface.
Args:
get_inference_model: Callable that builds the inference model.
cache_seq_interface: The CachedSequenceInterface containing sequence and cache config.
ad_config: Optional LLM configuration.
mapping: Optional distributed mapping configuration.
dist: Optional distributed communication backend used by the engine.
reporting_info: Reporting configuration for logging.
"""
# NOTE (lucaslie): create a fake Namespace to satisfy PyExecutor requirements...
# This is not correctly declared in the base ModelEngine class though...
self.llm_args = SimpleNamespace()
@ -514,8 +495,8 @@ class ADEngine(ModelEngine):
self.llm_args.batch_wait_timeout_ms = 0
self.llm_args.batch_wait_timeout_iters = 0
self.llm_args.batch_wait_max_tokens_ratio = 0.0
self.llm_args.max_num_tokens = seq_info.max_num_tokens
self.llm_args.max_seq_len = seq_info.max_seq_len
self.llm_args.max_num_tokens = cache_seq_interface.info.max_num_tokens
self.llm_args.max_seq_len = cache_seq_interface.info.max_seq_len
self.iter_counter = 0
self.iter_states = {}
@ -546,13 +527,10 @@ class ADEngine(ModelEngine):
)
# For compatibility with PyTorchModelEngine utilities
self.batch_size = seq_info.max_batch_size
self.batch_size = cache_seq_interface.info.max_batch_size
# construct cache sequence interface
self.cache_seq_interface = CachedSequenceInterface(
sequence_info=seq_info,
device=device,
)
# Store the cache sequence interface
self.cache_seq_interface = cache_seq_interface
# build model
self.model = get_inference_model(self.cache_seq_interface)
@ -628,12 +606,18 @@ class ADEngine(ModelEngine):
# gather indices for logits
logits_gather_indices: List[int] = []
page_size = self.cache_seq_interface.info.page_size
page_size = kv_cache_manager.tokens_per_block
dummy_token = -1
num_ctx_requests = len(context_requests)
num_ctx_tokens = 0
num_generation_tokens = 0
# Helper to get slot index - use mamba_cache_index if available (MambaHybridCacheManager)
def _get_slot_idx(request) -> int:
if hasattr(kv_cache_manager, "mamba_cache_index"):
return kv_cache_manager.mamba_cache_index[request.py_request_id]
return request.seq_slot
# look at context requests first
for request in context_requests:
# store input ids and pos of first token in sequence
@ -669,8 +653,8 @@ class ADEngine(ModelEngine):
position_ids.append(list(range(input_pos[-1], seq_len_with_cache[-1])))
# store seq slot idx
slot_idx.append(request.seq_slot)
# store seq slot idx (use mamba_cache_index if available)
slot_idx.append(_get_slot_idx(request))
use_initial_states.append(input_pos[-1] > 0)
# store extra arguments
@ -749,7 +733,8 @@ class ADEngine(ModelEngine):
num_generation_tokens += 1 + get_draft_token_length(request)
request.py_batch_idx = request.seq_slot
slot_idx.append(request.seq_slot)
# store seq slot idx (use mamba_cache_index if available)
slot_idx.append(_get_slot_idx(request))
use_initial_states.append(input_pos[-1] > 0)
seq_len.append(len(input_ids[-1]))
@ -941,11 +926,11 @@ def create_draft_model_engine_maybe(
draft_spec_config = copy.copy(spec_config)
kv_cache_config = ad_config.kv_cache_config
kv_cache_config_tuned = target_engine.cache_seq_interface.kv_cache_config_tuned
attn_runtime_features = AttentionRuntimeFeatures(
chunked_prefill=ad_config.enable_chunked_prefill,
cache_reuse=kv_cache_config.enable_block_reuse,
cache_reuse=kv_cache_config_tuned.enable_block_reuse,
has_speculative_draft_tokens=has_spec_drafter,
chunk_size=target_engine.llm_args.max_num_tokens,
)
@ -1108,40 +1093,15 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
else None
)
# check kvcache config for partial block reuse
# TODO: copy_on_partial_reuse is not supported yet, see
# https://github.com/NVIDIA/TensorRT-LLM/issues/7142 for more details.
enable_block_reuse = ad_config.kv_cache_config.enable_block_reuse
enable_partial_reuse = ad_config.kv_cache_config.enable_partial_reuse
copy_on_partial_reuse = ad_config.kv_cache_config.copy_on_partial_reuse
if enable_block_reuse and enable_partial_reuse and copy_on_partial_reuse:
raise RuntimeError(
f"partial block reuse with {copy_on_partial_reuse=} set to True is NOT supported"
" in AutoDeploy. Please set it to False via the kv_cache_config.copy_on_partial_reuse "
"field in tensorrt_llm._torch.auto_deploy.llm_args.LlmArgs."
)
# TODO: detect whether SSM layer is present in the model and raise an error or disable block
# reuse with a warning --> see https://github.com/NVIDIA/TensorRT-LLM/issues/7142. For now, we
# just emit a general warning.
if enable_block_reuse:
ad_logger.warning(
f"{enable_block_reuse=} is enabled. Note that this is not supported for SSM layers and"
" may lead to incorrect results if the model contains SSM layers."
)
# resource managers
kv_cache_manager = _CacheManagerWithFakePool(
ad_config.kv_cache_config,
num_blocks=engine.cache_seq_interface.info.num_pages,
tokens_per_block=ad_config.attn_page_size,
max_seq_len=ad_config.max_seq_len,
max_batch_size=ad_config.max_batch_size,
)
# KVCacheManager is now created and managed by CachedSequenceInterface during the
# initialize_cache/resize_kv_cache transform pipeline. Get it from the interface.
kv_cache_manager = engine.cache_seq_interface.kv_cache_manager
kv_cache_config_tuned = engine.cache_seq_interface.kv_cache_config_tuned
seq_slot_manager = SeqSlotManager(max_num_sequences=max_num_sequences)
draft_kv_cache_manager = create_draft_kv_cache_manager_maybe(
draft_model_engine, ad_config, dist_mapping
draft_model_engine, ad_config, kv_cache_config_tuned, dist_mapping
)
resource_manager = ResourceManager(
@ -1160,7 +1120,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
# Chunked prefill
if ad_config.enable_chunked_prefill:
chunk_unit_size = ad_config.attn_page_size
chunk_unit_size = kv_cache_config_tuned.tokens_per_block
chunking_policy = ContextChunkingPolicy.FIRST_COME_FIRST_SERVED
ctx_chunk_config: Tuple[StrEnum, int] = (chunking_policy, chunk_unit_size)
else:
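With chunked prefill enabled, the chunk unit now comes from the tuned `KvCacheConfig.tokens_per_block` instead of the removed `attn_page_size`. A toy illustration of how a long context splits into such chunk units; the actual chunk scheduling is handled by the executor's FIRST_COME_FIRST_SERVED context chunking policy shown above:

```python
# Toy illustration only: split a context of num_ctx_tokens into chunks of
# chunk_unit_size tokens (chunk_unit_size == tokens_per_block after this change).
def split_into_chunks(num_ctx_tokens: int, chunk_unit_size: int) -> list:
    return [
        min(chunk_unit_size, num_ctx_tokens - start)
        for start in range(0, num_ctx_tokens, chunk_unit_size)
    ]


print(split_into_chunks(300, 64))  # [64, 64, 64, 64, 44]
```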

View File

@ -71,14 +71,15 @@ class DemoEngine(ADEngine):
currently available token slots in the assigned pages and assign a new, previously
unassigned page if needed.
"""
si = self.cache_seq_interface.info
page_assignments = si.page_assignments
page_assignments = self.cache_seq_interface.info.page_assignments
num_pages = self.cache_seq_interface.kv_cache_manager.blocks_in_primary_pool
tokens_per_block = self.cache_seq_interface.kv_cache_manager.tokens_per_block
free_pages = set(range(si.num_pages)) - {i for pages in page_assignments for i in pages}
free_pages = set(range(num_pages)) - {i for pages in page_assignments for i in pages}
updated_assignments = []
for t_l, pages in zip(total_lens, page_assignments):
extra_tokens = t_l - len(pages) * si.page_size
num_extra_pages = (extra_tokens // si.page_size) + (extra_tokens > 0)
extra_tokens = t_l - len(pages) * tokens_per_block
num_extra_pages = (extra_tokens // tokens_per_block) + (extra_tokens > 0)
updated_assignments.append(pages + [free_pages.pop() for _ in range(num_extra_pages)])
return updated_assignments
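A standalone restatement of the page-extension arithmetic above, with the pool size and page size now read from the KV cache manager rather than `SequenceInfo`; pure Python, mirroring the logic in the diff:

```python
# Pure-Python sketch of the page-extension logic shown above; free pages are
# whatever is not referenced by the current assignments.
def update_page_assignments(total_lens, page_assignments, num_pages, tokens_per_block):
    free_pages = set(range(num_pages)) - {i for pages in page_assignments for i in pages}
    updated = []
    for t_l, pages in zip(total_lens, page_assignments):
        extra_tokens = t_l - len(pages) * tokens_per_block
        num_extra_pages = (extra_tokens // tokens_per_block) + (extra_tokens > 0)
        updated.append(pages + [free_pages.pop() for _ in range(num_extra_pages)])
    return updated


# two sequences, 64-token pages, 16 pages in the primary pool
print(update_page_assignments([70, 10], [[0], [1]], num_pages=16, tokens_per_block=64))
```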

View File

@ -1,95 +1,473 @@
from typing import Callable, Dict, List, Optional, Tuple, final
import copy
import functools
import math
from typing import Callable, Dict, Optional, Tuple, Union, final
import torch
import torch.nn as nn
from torch._prims_common import DeviceLikeType
from ..custom_ops.attention_interface import GetCacheCallable, SequenceInfo
import tensorrt_llm.bindings
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm.mapping import Mapping
from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager
from ...pyexecutor.resource_manager import KVCacheManager
from ..custom_ops.attention_interface import (
PagedResourceHandler,
ResourceHandler,
ResourceHandlerDict,
SequenceInfo,
StateResourceHandler,
)
from ..distributed.common import all_gather_object, get_world_size
from ..distributed.common import is_initialized as is_distributed_initialized
from ..utils.cuda_mem_tracker import bytes_to, get_mem_info
from ..utils.logger import ad_logger
CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType
DataType = tensorrt_llm.bindings.DataType
def with_pre_callback(method, callback):
"""Wrap method to call callback before the original method."""
@functools.wraps(method)
def wrapper(*args, **kwargs):
callback()
return method(*args, **kwargs)
return wrapper
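`with_pre_callback` is used further down to patch the cache manager's `shutdown` so the cache views are dropped before the pool is released. A tiny usage sketch (standalone copy of the helper above):

```python
# Minimal usage sketch of with_pre_callback.
import functools


def with_pre_callback(method, callback):
    @functools.wraps(method)
    def wrapper(*args, **kwargs):
        callback()
        return method(*args, **kwargs)

    return wrapper


class Pool:
    def shutdown(self):
        print("releasing pool")


pool = Pool()
pool.shutdown = with_pre_callback(pool.shutdown, lambda: print("clearing cache views"))
pool.shutdown()  # prints "clearing cache views" then "releasing pool"
```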
@final
class CachedSequenceInterface:
"""An interface responsible for maintaining information about sequences and their caches."""
"""An interface responsible for maintaining information about sequences and their caches.
This class is the single source of truth for sequence and cache configuration. It creates
SequenceInfo internally, ensuring that tokens_per_block and other fields from KvCacheConfig
are always consistent.
"""
def __init__(
self, sequence_info: SequenceInfo, device: Optional[DeviceLikeType] = None
self,
max_seq_len: int,
max_batch_size: int,
device: Optional[DeviceLikeType] = None,
kv_cache_config: Optional[KvCacheConfig] = None,
max_num_tokens: Optional[int] = None,
vocab_size_padded: Optional[int] = None,
) -> None:
"""Initialize the CachedSequenceInterface.
Args:
max_seq_len: Maximum sequence length including input and generated tokens.
max_batch_size: Maximum number of sequences (requests) that can be processed.
device: Target device for tensors. Defaults to "cuda".
kv_cache_config: KV cache configuration. If None, uses default KvCacheConfig.
max_num_tokens: Maximum total tokens across all sequences. If None, computed from
max_seq_len and max_batch_size.
vocab_size_padded: Padded vocabulary size of the model.
"""
# TODO (lucaslie): this is somewhat circular/confusing. Here `device` denotes the desired
# device and not the actual device unlike, e.g., in SequenceInfo. We rely on the attribute
# here to read the desired device across the inference optimizer pipeline. We should ideally
# think about a better way to handle this,
# see https://github.com/NVIDIA/TensorRT-LLM/issues/8371
self.device = device or "cuda"
self.info = sequence_info
self._cache_initializers: Dict[str, GetCacheCallable] = {}
# Initialize kv_cache_config first since SequenceInfo needs tokens_per_block from it
self._kv_cache_config_original: KvCacheConfig = kv_cache_config or KvCacheConfig()
self._kv_cache_config_tuned: Optional[KvCacheConfig] = None
# Create SequenceInfo internally, using tokens_per_block from kv_cache_config
self.info = SequenceInfo(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
tokens_per_block=self._kv_cache_config_original.tokens_per_block,
max_num_tokens=max_num_tokens,
vocab_size_padded=vocab_size_padded,
)
self._resource_lookup: ResourceHandlerDict = {}
self._caches: Dict[str, torch.Tensor] = {}
# KVCacheManager (or MambaHybridCacheManager) for managed resources
self._kv_cache_manager: Optional[Union[KVCacheManager, MambaHybridCacheManager]] = None
# Ordered dicts tracking resource handlers by type
self._paged_cache_order: ResourceHandlerDict = {} # Paged resources (kv caches)
self._state_resource_order: ResourceHandlerDict = {} # State resources (ssm states)
@property
def args(self) -> Tuple[torch.Tensor, ...]:
"""Return all the graph arguments owned by this interface."""
return (*self.info.args, *self._caches.values())
return tuple(self.named_args.values())
@property
def named_args(self) -> Dict[str, torch.Tensor]:
"""Return all the named arguments owned by this interface."""
return {**self.info.named_args, **self._caches}
@property
def all_future_arg_names(self) -> List[str]:
"""Return all the argument names owned by this interface including uninitialized caches."""
return list(self.info.named_args.keys()) + list(self._cache_initializers.keys())
def to(self, *args, **kwargs) -> None:
self.info.to(*args, **kwargs)
if self._caches:
for cache in self._caches.values():
# Only move locally-allocated caches (paged/state caches are managed by cache managers)
for name, cache in self._caches.items():
if name not in self._paged_cache_order and name not in self._state_resource_order:
cache.to(*args, **kwargs)
def add_cache(self, name: str, get_cache: GetCacheCallable) -> None:
"""Add a cache initializer to the cache interface."""
self._cache_initializers[name] = get_cache
def update_kv_cache_config(self, **kwargs) -> None:
"""Update the KVCacheConfig with the given kwargs."""
for k, v in kwargs.items():
if k in type(self._kv_cache_config_original).model_fields:
setattr(self._kv_cache_config_original, k, v)
else:
raise ValueError(f"Invalid KVCacheConfig field: {k}")
def initialize_caches(self) -> int:
"""Initialize caches using the cache initializers."""
assert not self._caches, "Caches already initialized."
self.info.to(self.device)
self._caches = {
name: get_cache(self.info) for name, get_cache in self._cache_initializers.items()
def add_resource(self, name: str, resource_handler: ResourceHandler) -> None:
"""Add a resource handler to the cache interface."""
self._resource_lookup[name] = resource_handler
def _create_kv_cache_manager(self, max_tokens: Optional[int] = None) -> int:
"""Create KVCacheManager or MambaHybridCacheManager with multi-layer byte-level params.
This uses a multi-layer approach with byte-level abstraction:
- Paged resources: Each resource gets its own layer in KVCacheManager with
num_kv_heads=bytes_per_token for that resource, head_dim=1.
- State resources: Each resource gets its own layer in MambaCacheManager with
head_dim=bytes_per_slot for that resource.
Each layer's cache is contiguous, avoiding byte-offset slicing within layers.
When state resources exist, MambaHybridCacheManager is used to manage both.
Important NOTE on contiguity of managed resources:
- We only guarantee contiguity for an individual page or an individual state slot.
- Outside of these individual pages/slots, resources are NOT guaranteed to be contiguous.
Args:
max_tokens: Maximum number of tokens to allocate. If provided, the effective limit is
the minimum of this value and max_tokens in kv_cache_config.
Returns:
The final number of tokens that can be cached in the KVCacheManager.
NOTE: this number may differ from the provided ``max_tokens`` arg for two reasons:
1. the final number of tokens is synced (min) across ranks
2. it is rounded to a multiple of tokens_per_block
"""
# Build per-layer num_kv_heads list for paged resources
# Each paged resource becomes one "layer" with num_kv_heads = bytes_per_token
num_kv_heads_per_layer = [
math.prod(h.token_shape) * h.dtype.itemsize for h in self._paged_cache_order.values()
]
# Calculate total bytes per slot for state resources (modeled as single layer)
cumulative_bytes_per_state = [0]
for name, handler in self._state_resource_order.items():
byte_size = math.prod(handler.state_shape) * handler.dtype.itemsize
cumulative_bytes_per_state.append(cumulative_bytes_per_state[-1] + byte_size)
# Make a deep copy of the kv_cache_config to avoid modifying the original object
kv_cache_config = copy.deepcopy(self._kv_cache_config_original)
# Disable copy_on_partial_reuse
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/10966
if kv_cache_config.copy_on_partial_reuse:
kv_cache_config.copy_on_partial_reuse = False
ad_logger.info("Disabling copy_on_partial_reuse for AutoDeploy backend.")
# Update kv_cache_config based on max_tokens if provided
if max_tokens is not None:
# sync max_tokens across ranks
if is_distributed_initialized():
max_tokens_gathered = [None] * get_world_size()
all_gather_object(max_tokens_gathered, max_tokens)
max_tokens = min(max_tokens_gathered)
kv_cache_config.free_gpu_memory_fraction = None
kv_cache_config.max_tokens = min(kv_cache_config.max_tokens or max_tokens, max_tokens)
# Check if we should disable block reuse
if kv_cache_config.enable_block_reuse and not self.is_paged():
kv_cache_config.enable_block_reuse = False
ad_logger.info(f"Setting {kv_cache_config.enable_block_reuse=} for non-paged models.")
# Make sure to set free_gpu_memory_fraction to None if set to 0.0
# NOTE: KVCacheConfig validator enforces that free_gpu_memory_fraction must be between 0.0
# and 1.0 but we allow 0.0 to be set to disable resizing (corresponding to None in the
# manager).
if kv_cache_config.free_gpu_memory_fraction == 0.0:
kv_cache_config.free_gpu_memory_fraction = None
# Common KV cache parameters
kv_cache_kwargs = {
"kv_cache_config": kv_cache_config,
"kv_cache_type": CacheTypeCpp.SELFKONLY, # kv_factor=1, treat K, V separately
"num_layers": len(self._paged_cache_order), # correct num layers
"num_kv_heads": num_kv_heads_per_layer, # per-layer bytes_per_token
"head_dim": 1, # all bytes in num_kv_heads
"tokens_per_block": kv_cache_config.tokens_per_block,
"max_seq_len": self.info.max_seq_len,
"max_batch_size": self.info.max_batch_size,
"mapping": Mapping(),
# NOTE (lucaslie): this is the only 1-byte dtype currently supported by the
# KVCacheManager. Ideally, we would use the typical uint8 dtype for byte-level
# abstraction, but this is not supported.
"dtype": DataType.FP8, # 1-byte dtype for byte-level abstraction
"layer_mask": None,
# NOTE (lucaslie): we can always run with False here since when we are estimating, we
# are explicitly setting the max_tokens in which case it's okay to use False here since
# we don't rely on free_gpu_memory_fraction inside the KVCacheManager. This is similar
# to _torch.pyexecutor._util.KVCacheCreator, which explicitly estimates the max_tokens
# outside of the KVCacheManager.
"is_estimating_kv_cache": False,
}
# update args if we are just doing a dummy cache manager
if not len(self._paged_cache_order):
kv_cache_kwargs.update(
{
"num_layers": 1,
"num_kv_heads": 1,
"head_dim": 1,
}
)
if self._state_resource_order:
# NOTE: +1 for cuda graph padding
kv_cache_kwargs["max_batch_size"] = self.info.max_num_state_slots
self._kv_cache_manager = MambaHybridCacheManager(
# Mamba params for single-layer byte buffer
mamba_d_state=1,
mamba_d_conv=1, # conv_states will have shape [..., 0] (empty)
mamba_num_heads=1,
mamba_n_groups=1,
mamba_head_dim=cumulative_bytes_per_state[-1], # Total bytes per slot
mamba_num_layers=1, # Single layer
mamba_layer_mask=None, # Single enabled layer
mamba_cache_dtype=torch.uint8, # Byte-level
mamba_ssm_cache_dtype=torch.uint8, # Byte-level
# KV cache params
**kv_cache_kwargs,
)
else:
# No state resources - use pure KVCacheManager
self._kv_cache_manager = KVCacheManager(**kv_cache_kwargs)
# store the tuned kv_cache_config
self._kv_cache_config_tuned = kv_cache_config
# Ensure cache_loc capacity is sufficient for the new KVCacheManager
blocks_in_primary_pool = self._kv_cache_manager.blocks_in_primary_pool
tokens_per_block = self._kv_cache_manager.tokens_per_block
self.info.estimate_cache_loc_capacity(blocks_in_primary_pool)
# Create paged resource views from per-layer buffers
for layer_idx, (name, handler) in enumerate(self._paged_cache_order.items()):
view = self._kv_cache_manager.get_buffers(layer_idx, kv_layout="NHD")
view = view.view(blocks_in_primary_pool, tokens_per_block, -1).view(handler.dtype)
view = view.view(blocks_in_primary_pool, tokens_per_block, *handler.token_shape)
# Sanity check on contiguity of individual pages
view_one_page = view[0]
assert view_one_page.is_contiguous(), f"Per-page cache for {name} is not contiguous"
self._caches[name] = view
for layer_idx, (name, handler) in enumerate(self._state_resource_order.items()):
num_states = len(self._kv_cache_manager.state_indices)
# Get the single-layer ssm_states buffer
# ssm_states shape: [1, num_states, 1, total_bytes_per_slot, 1]
ssm_buffer = self._kv_cache_manager.get_ssm_states(0)
# Flatten to [max_batch, total_bytes_per_slot_for_all_layers]
ssm_buffer = ssm_buffer.view(num_states, -1)
offset_start = cumulative_bytes_per_state[layer_idx]
offset_end = cumulative_bytes_per_state[layer_idx + 1]
# Slice at byte offset, reinterpret dtype, reshape
view = ssm_buffer[:, offset_start:offset_end]
view = view.view(handler.dtype)
view = view.view(num_states, *handler.state_shape)
# Sanity check on contiguity of individual state slots
assert view[0].is_contiguous(), f"Per-slot state for {name} cache is not contiguous"
self._caches[name] = view
# Patch shutdown to clear cache views before pool release
self._kv_cache_manager.shutdown = with_pre_callback(
self._kv_cache_manager.shutdown,
self._clear_cache_views,
)
max_resource_count = self._kv_cache_manager.get_max_resource_count()
max_tokens_final = max_resource_count * self._kv_cache_manager.tokens_per_block
return max_tokens_final
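The byte-level trick above models each paged resource as one "layer" whose `num_kv_heads` equals its bytes-per-token, then reinterprets the manager's 1-byte pool as the handler's real dtype and shape. A minimal standalone sketch of that reinterpretation with plain torch (no KVCacheManager involved):

```python
# Standalone sketch of the byte-level view trick: a 1-byte pool is reinterpreted
# as a typed [blocks, tokens_per_block, num_kv_heads, head_dim] cache view.
import math

import torch

blocks_in_primary_pool, tokens_per_block = 8, 32
num_kv_heads, head_dim, dtype = 4, 64, torch.float16
token_shape = (num_kv_heads, head_dim)

# bytes per token for this resource == its "num_kv_heads" in the byte-level manager
bytes_per_token = math.prod(token_shape) * dtype.itemsize

# the manager's per-layer pool: 1-byte elements, head_dim=1
pool = torch.zeros(blocks_in_primary_pool, tokens_per_block, bytes_per_token, dtype=torch.uint8)

# reinterpret bytes as the real dtype, then reshape to the handler's token shape
view = pool.view(blocks_in_primary_pool, tokens_per_block, -1).view(dtype)
view = view.view(blocks_in_primary_pool, tokens_per_block, *token_shape)

assert view[0].is_contiguous()  # contiguity is only guaranteed per page
print(view.shape)  # torch.Size([8, 32, 4, 64])
```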
def initialize_resources(self) -> int:
"""Initialize resources - paged/state caches via cache managers, others separately.
Paged resources are managed by KVCacheManager (or the KV portion of MambaHybridCacheManager).
State resources are managed by the Mamba portion of MambaHybridCacheManager.
Other resources are allocated locally as a fallback.
Returns:
The number of caches initialized.
"""
assert not self._caches and not self._paged_cache_order, "Caches already initialized."
self.info.to(self.device)
# Separate resources by type
for name, handler in self._resource_lookup.items():
if isinstance(handler, PagedResourceHandler):
self._paged_cache_order[name] = handler
self._caches[name] = None # Will be set by _create_kv_cache_manager
elif isinstance(handler, StateResourceHandler):
self._state_resource_order[name] = handler
self._caches[name] = None # Will be set by _create_kv_cache_manager
else:
# Unknown handler type - allocate locally (fallback)
self._caches[name] = handler.allocate(self.info)
# Create unified cache manager (handles both paged and state resources)
if self.needs_resize() or self._requires_token_estimate():
max_tokens_estimate = self.info.estimate_cache_tokens_per_forward()
else:
# if we don't need a resize, we will just use the original settings in kv_cache_config
# instead of passing in an overwrite here.
max_tokens_estimate = None
self._create_kv_cache_manager(max_tokens=max_tokens_estimate)
return len(self._caches)
def current_cache_size_bytes(self) -> int:
"""Calculate and return the total size of all caches in bytes."""
total_size = 0
for name, cache in self._caches.items():
# this hack is needed since _caches also contains global buffers such as freqs_cis.
if "cache" in name:
total_size += cache.element_size() * cache.numel()
return total_size
def is_paged(self) -> bool:
"""Return True if all resources are paged and part of the KVCacheManager."""
return set(self._paged_cache_order.keys()) == set(self._resource_lookup.keys())
def current_kv_cache_size_bytes(self) -> int:
"""Return size in bytes of KV caches only (k_cache_*, v_cache_*).
def _requires_token_estimate(self) -> bool:
"""Check if our kv_cache_config requires."""
return (
self._kv_cache_config_original.free_gpu_memory_fraction in [None, 0.0]
and self._kv_cache_config_original.max_tokens is None
)
Excludes SSM/conv/etc. which do not scale with num_pages.
def needs_resize(self) -> bool:
"""Check if we need a resize or not."""
has_paged = bool(self._paged_cache_order)
return has_paged and self._kv_cache_config_original.free_gpu_memory_fraction not in [
None,
0.0,
]
def resize_kv_cache_manager(self, mem_exclude: int = 0) -> None:
"""Shutdown existing KVCacheManager and create new one with optimal capacity.
Args:
mem_exclude: Extra memory to exclude from the calculation of optimal capacity.
This is in bytes and typically the memory reserved for the forward pass.
This implements the two-phase approach: after running a forward pass during estimation
to allocate intermediate memory, call this method to recreate the KVCacheManager.
The new manager will compute optimal capacity based on current free GPU memory
via calculate_max_num_blocks.
"""
total_size = 0
for name, cache in self._caches.items():
if name.startswith("k_cache_") or name.startswith("v_cache_"):
total_size += cache.element_size() * cache.numel()
return total_size
if not self.needs_resize():
return
def resize_cache(self, new_num_pages: int):
"""Resize the cache to the new number of pages."""
# TODO: We should do some sanity check on the new number of pages.
self.info.num_pages = new_num_pages
for name, cache in self._caches.items():
# We assume cache is a tensor of shape (max_batch_size, page_size, n_heads, head_dim)
# TODO: cache resize should ideally be handled via a callback to the AttentionDescriptor
# to avoid hard-coding any assumptions about the cache shape or its "pagedness"
if "k_cache" in name or "v_cache" in name:
current_shape = cache.shape
new_shape = (new_num_pages, *current_shape[1:])
cache.resize_(new_shape)
# get per-token cache size for resizable resources
paged_cache_bytes_per_token = self._kv_cache_manager.get_cache_bytes_per_token()
# get total cache size of state resources that cannot be resized
# NOTE: this does NOT include resources handled OUTSIDE of the KVCacheManager or
# MambaHybridCacheManager. Those will persist and will be accounted for via free_mem even
# after the initial kv_cache_manager is shut down.
state_cache_bytes_total = sum(
cache.numel() * cache.element_size()
for name, cache in self._caches.items()
if name in self._state_resource_order
)
# get unmanaged cache size
unmanaged_cache_bytes_total = sum(
cache.numel() * cache.element_size()
for name, cache in self._caches.items()
if name not in self._paged_cache_order and name not in self._state_resource_order
)
# Shutdown existing KVCacheManager to free memory
self._kv_cache_manager.shutdown()
# Get current free GPU memory (roughly includes model weights + non-managed resources)
_, free_mem, *_ = get_mem_info(empty_cache=True)
# Compute available memory for the KVCacheManager
# NOTE: free_mem was obtained AFTER shutdown of initial KVCacheManager - hence it accounts
# for unmanaged resources but it does NOT account for state resources since those were
# freed as part of the shutdown.
free_gpu_memory_fraction = self._kv_cache_config_original.free_gpu_memory_fraction
mem_for_paged_optimal = (
free_mem - state_cache_bytes_total - mem_exclude
) * free_gpu_memory_fraction
# Check how many tokens we can fit into the paged cache
max_tokens_optimal = int(mem_for_paged_optimal // paged_cache_bytes_per_token)
# Create new KVCacheManager with final capacity
max_tokens_final = self._create_kv_cache_manager(max_tokens=max_tokens_optimal)
# Log resulting memory information
mem_info = [
f"free_mem={bytes_to(free_mem, unit='GB'):.2f}GB",
f"free_gpu_memory_fraction={free_gpu_memory_fraction}",
f"mem_exclude={bytes_to(mem_exclude, unit='GB'):.2f}GB",
f"mem_exclude_for_state={bytes_to(state_cache_bytes_total, unit='GB'):.2f}GB",
f"mem_for_paged_optimal={bytes_to(mem_for_paged_optimal, unit='GB'):.2f}GB",
]
total_cache_bytes = (
mem_for_paged_optimal + state_cache_bytes_total + unmanaged_cache_bytes_total
)
mem_cache_info = [
f"Max Tokens={max_tokens_final}",
f"Paged={bytes_to(mem_for_paged_optimal, unit='GB'):.2f}GB",
f"State={bytes_to(state_cache_bytes_total, unit='GB'):.2f}GB",
f"Unmanaged={bytes_to(unmanaged_cache_bytes_total, unit='GB'):.2f}GB",
f"Total={bytes_to(total_cache_bytes, unit='GB'):.2f}GB",
]
ad_logger.info(f"Mem info for resize: {' | '.join(mem_info)}")
ad_logger.info(f"Final Cache Mem: {' | '.join(mem_cache_info)}")
@property
def kv_cache_manager(self) -> Optional[KVCacheManager]:
"""Return the KVCacheManager managing paged resources, or None if not initialized."""
assert self._kv_cache_manager is not None, "KVCacheManager not initialized."
return self._kv_cache_manager
@property
def kv_cache_config_tuned(self) -> KvCacheConfig:
"""Return the KVCacheConfig tuned for the KVCacheManager."""
assert None not in [self._kv_cache_manager, self._kv_cache_config_tuned], (
"KVCacheManager not initialized."
)
return self._kv_cache_config_tuned
@property
def kv_cache_config(self) -> KvCacheConfig:
"""Return the original KVCacheConfig as passed in."""
return self._kv_cache_config_original
def _clear_cache_views(self) -> None:
"""Set paged and state cache views to None before pool release."""
self._kv_cache_config_tuned = None
for name in self._paged_cache_order:
self._caches[name] = None
for name in self._state_resource_order:
self._caches[name] = None
def shutdown(self) -> None:
"""Shutdown and release all resources."""
if self._kv_cache_manager is not None:
self._kv_cache_manager.shutdown()
self._kv_cache_config_tuned = None
self._caches.clear()
GetInferenceModel = Callable[[CachedSequenceInterface], nn.Module]

View File

@ -5,7 +5,8 @@ This module defines the base classes and interfaces for all transforms.
import time
from abc import ABC
from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from functools import total_ordering, wraps
from typing import Any, Callable, Dict, Mapping, Tuple, Type, Union, final
@ -24,8 +25,60 @@ from ..utils._graph import (
placeholders_on_meta,
run_shape_prop,
)
from ..utils.cuda_mem_tracker import get_mem_info
from ..utils.logger import ad_logger
# ANSI color codes for log formatting (toggle via _ENABLE_LOG_COLORS below)
# NOTE: colors disabled by default to make logging in CI/CD pipelines easier to read
_ENABLE_LOG_COLORS = False
class _Colors:
RESET = "\033[0m" if _ENABLE_LOG_COLORS else ""
BOLD = "\033[1m" if _ENABLE_LOG_COLORS else ""
DIM = "\033[2m" if _ENABLE_LOG_COLORS else ""
CYAN = "\033[36m" if _ENABLE_LOG_COLORS else ""
MAGENTA = "\033[35m" if _ENABLE_LOG_COLORS else ""
GREEN = "\033[32m" if _ENABLE_LOG_COLORS else ""
YELLOW = "\033[33m" if _ENABLE_LOG_COLORS else ""
ORANGE = "\033[38;5;208m" if _ENABLE_LOG_COLORS else ""
@dataclass
class MemStats:
"""Memory statistics snapshot for tracking CUDA memory usage."""
tot: float
free: float
resv: float
alloc: float
frag: float
def diff(self, other: "MemStats") -> "MemStats":
"""Calculate the difference (self - other)."""
return MemStats(
tot=self.tot - other.tot,
free=self.free - other.free,
resv=self.resv - other.resv,
alloc=self.alloc - other.alloc,
frag=self.frag - other.frag,
)
def to_dict(self) -> Dict[str, float]:
"""Convert to dictionary for serialization."""
return {
"tot": self.tot,
"free": self.free,
"resv": self.resv,
"alloc": self.alloc,
"frag": self.frag,
}
@classmethod
def from_dict(cls, d: Dict[str, float]) -> "MemStats":
"""Create from dictionary."""
return cls(tot=d["tot"], free=d["free"], resv=d["resv"], alloc=d["alloc"], frag=d["frag"])
class TransformError(Exception):
"""An exception raised when a transform fails."""
@ -114,6 +167,11 @@ class TransformConfig(BaseModel):
description="Whether this transform requires shape propagation before it is applied.",
)
expect_mem_change: bool = Field(
default=False,
description="Whether this transform is expected to cause changes in CUDA memory stats.",
)
AutodeployMeta = Dict[str, Any]
_UntypedInferenceOptimizerConfig = Dict[str, Any]
@ -223,6 +281,7 @@ class BaseTransform(ABC):
config: TransformConfig # overwrite type hint if other config cls is used in subclass!
_autodeploy_meta_key: str = "_autodeploy"
_history_key: str = "transform_history"
_mem_history_key: str = "mem_history"
_transform_key: str # Set by TransformRegistry.register() decorator
@classmethod
@ -324,6 +383,9 @@ class BaseTransform(ABC):
# show debug info for debug config
ad_logger.debug(f"{t_name} config: {self.config}")
# capture memory stats at the start
mem_pre = self._get_mem_stats(empty_cache=True)
# store some timing information
elapsed_time_total = -time.time()
elapsed_time_pre_cleanup = 0.0
@ -340,23 +402,28 @@ class BaseTransform(ABC):
self.config.requires_shape_prop,
info.is_clean,
info.has_valid_shapes,
phase="pre",
)
elapsed_time_pre_cleanup += time.time()
# run the transform in an error-handling wrapper if desired
elapsed_time_apply = -time.time()
if self.config.skip_on_error:
try:
with self._apply_logging_context():
self._log_info("applying transform...")
if self.config.skip_on_error:
try:
mod, info_apply = self._apply_per_gm_or_whole_model(
mod, cm, factory, shared_config
)
except Exception as e:
error_msg = f"Transform {t_name} failed"
ad_logger.warning(f"{error_msg}: {e}")
info_apply = TransformInfo(skipped=True, num_matches=0)
else:
# handle this here normally to improve debugging and error message
mod, info_apply = self._apply_per_gm_or_whole_model(
mod, cm, factory, shared_config
)
except Exception as e:
error_msg = f"Transform {t_name} failed"
ad_logger.warning(f"{error_msg}: {e}")
info_apply = TransformInfo(skipped=True, num_matches=0)
else:
# handle this here normally to improve debugging and error message
mod, info_apply = self._apply_per_gm_or_whole_model(mod, cm, factory, shared_config)
elapsed_time_apply += time.time()
# we cannot say it's clean if the previous wasn't clean even if this one is
@ -371,31 +438,40 @@ class BaseTransform(ABC):
self.config.run_shape_prop,
info.is_clean,
info.has_valid_shapes,
phase="post",
)
elapsed_time_post_cleanup += time.time()
elapsed_time_total += time.time()
# capture memory stats at the end and log summary (only log if enabled)
mem_post = self._get_mem_stats(empty_cache=True)
if self.config.enabled:
self._log_mem_summary(mem_pre, mem_post, self.config.expect_mem_change)
# log the result of the transform
log_msgs = [
f"enabled={self.config.enabled}",
"skipped=True" if info.skipped else f"num_matches={info.num_matches}",
f"is_clean={info.is_clean}",
f"has_valid_shapes={info.has_valid_shapes}",
]
self._log_info(", ".join(log_msgs))
log_msgs_timing = [
f"elapsed time: total={elapsed_time_total:.3f}s",
f"pre_cleanup={elapsed_time_pre_cleanup:.3f}s",
f"apply={elapsed_time_apply:.3f}s",
f"post_cleanup={elapsed_time_post_cleanup:.3f}s",
]
self._log_info(", ".join(log_msgs_timing))
self._log_transform_summary(
enabled=self.config.enabled,
skipped=info.skipped,
num_matches=info.num_matches,
elapsed_total=elapsed_time_total,
elapsed_pre=elapsed_time_pre_cleanup,
elapsed_apply=elapsed_time_apply,
elapsed_post=elapsed_time_post_cleanup,
)
ad_logger.debug(f"Model after {t_name}: {mod}")
# update + store new meta data
# update + store new meta data (transform history and memory history)
history[t_name] = info
autodeploy_meta[self._history_key] = history
# store memory history
mem_history: Dict[str, Dict[str, Dict[str, float]]] = autodeploy_meta.get(
self._mem_history_key, {}
)
mem_history[t_name] = {"pre": mem_pre.to_dict(), "post": mem_post.to_dict()}
autodeploy_meta[self._mem_history_key] = mem_history
self._set_autodeploy_meta(mod, autodeploy_meta)
# return the graph module
@ -423,11 +499,161 @@ class BaseTransform(ABC):
info = info & info_apply if info is not None else info_apply
return mod, info
@final
def _log_warning(self, *args: Any):
"""Log a warning message with the transform key."""
ad_logger.warning(*args)
@final
def _log_info(self, *args: Any):
"""Log an info message with the transform key."""
ad_logger.info(*args)
@final
def _log_debug(self, *args: Any):
"""Log a debug message with the transform key."""
ad_logger.debug(*args)
@contextmanager
def _apply_logging_context(self):
"""Context manager to add [APPLY] prefix to logs during transform execution."""
original_log = ad_logger.log
apply_label = "[APPLY]"
def _patched_log(severity, *msg):
# Prepend [APPLY] after any existing prefix
if msg:
first = msg[0] if isinstance(msg[0], str) else str(msg[0])
return original_log(severity, f"{apply_label} {first}", *msg[1:])
return original_log(severity, apply_label)
ad_logger.log = _patched_log # type: ignore[assignment]
try:
yield
finally:
ad_logger.log = original_log # type: ignore[assignment]
@final
def _log_cleanup_status(self, phase: str, action: str, reason: str = "") -> None:
"""Log cleanup status with colored formatting.
Args:
phase: "pre" or "post"
action: "ran" or "skipped"
reason: Description of what ran or why skipped
"""
label = f"{_Colors.CYAN}[{phase.upper()}-CLEANUP]{_Colors.RESET}"
if action == "skipped":
self._log_info(f"{label} {_Colors.DIM}skipped ({reason}){_Colors.RESET}")
else:
self._log_info(f"{label} {reason}")
@final
def _log_transform_summary(
self,
enabled: bool,
skipped: bool,
num_matches: int,
elapsed_total: float,
elapsed_pre: float,
elapsed_apply: float,
elapsed_post: float,
) -> None:
"""Log transform summary with colored formatting.
Args:
enabled: Whether the transform was enabled.
skipped: Whether the transform was skipped.
num_matches: Number of matches found.
elapsed_total: Total elapsed time in seconds.
elapsed_pre: Pre-cleanup elapsed time in seconds.
elapsed_apply: Apply elapsed time in seconds.
elapsed_post: Post-cleanup elapsed time in seconds.
"""
label = f"{_Colors.GREEN}[SUMMARY]{_Colors.RESET}"
timing_str = (
f"{elapsed_total:.3f}s "
f"(pre={elapsed_pre:.3f}s, apply={elapsed_apply:.3f}s, post={elapsed_post:.3f}s)"
)
if not enabled:
self._log_info(f"{label} {_Colors.DIM}disabled{_Colors.RESET}")
elif skipped:
self._log_info(f"{label} {_Colors.DIM}skipped{_Colors.RESET} | time: {timing_str}")
else:
self._log_info(f"{label} matches={num_matches} | time: {timing_str}")
@final
def _get_mem_stats(self, empty_cache: bool = True) -> MemStats:
"""Get current CUDA memory statistics.
Args:
empty_cache: Whether to empty the memory cache before getting the memory stats.
Returns:
MemStats object with current memory values in GB.
"""
tot, free, resv, alloc, frag = get_mem_info(empty_cache=empty_cache, unit="GB")
return MemStats(tot=tot, free=free, resv=resv, alloc=alloc, frag=frag)
@final
def _log_mem_summary(self, pre: MemStats, post: MemStats, expect_mem_change: bool) -> None:
"""Log memory summary with diff between pre and post stats.
Logs one of three cases:
1. Expected mem change: info log, magenta color
2. Unexpected mem change: warning log, yellow color
3. No mem change: debug log, no colors
Args:
pre: Memory stats captured before the transform.
post: Memory stats captured after the transform.
expect_mem_change: Whether this transform is expected to cause memory changes.
"""
diff = post.diff(pre)
# Threshold for detecting significant memory changes (in GB)
mem_change_threshold = 0.005
# Check if there was a significant memory change
has_mem_change = (
abs(diff.resv) >= mem_change_threshold
or abs(diff.alloc) >= mem_change_threshold
or abs(diff.frag) >= mem_change_threshold
)
def _fmt_val_with_delta(val: float, delta: float, color: str) -> str:
"""Format value with optional delta in the specified color."""
val_str = f"{val:6.2f}GB"
if abs(delta) < mem_change_threshold:
return val_str
sign = "+" if delta > 0 else ""
if color:
return f"{val_str} {_Colors.BOLD}{color}({sign}{delta:.2f}GB){_Colors.RESET}"
return f"{val_str} ({sign}{delta:.2f}GB)"
def _fmt_parts(color: str) -> str:
"""Format all memory parts with the specified color for deltas."""
parts = [
f"free: {_fmt_val_with_delta(post.free, diff.free, color)}",
f"resv: {_fmt_val_with_delta(post.resv, diff.resv, color)}",
f"alloc: {_fmt_val_with_delta(post.alloc, diff.alloc, color)}",
f"frag: {_fmt_val_with_delta(post.frag, diff.frag, color)}",
]
return " | ".join(parts)
if has_mem_change and expect_mem_change:
# Case 1: Expected mem change - info log, magenta
label = f"{_Colors.MAGENTA}[CUDA MEM DIFF (EXPECTED)]{_Colors.RESET}"
self._log_info(f"{label} {_fmt_parts(_Colors.MAGENTA)}")
elif has_mem_change and not expect_mem_change:
# Case 2: Unexpected mem change - warning log, yellow
label = f"{_Colors.YELLOW}[CUDA MEM DIFF (UNEXPECTED)]{_Colors.RESET}"
self._log_warning(f"{label} {_fmt_parts(_Colors.YELLOW)}")
else:
# Case 3: No mem change - debug log, no colors
self._log_debug(f"[CUDA MEM] {_fmt_parts('')}")
@final
def _get_autodeploy_meta(self, mod: nn.Module) -> AutodeployMeta:
"""Get the autodeploy metadata from the graphmodule."""
@ -450,8 +676,9 @@ class BaseTransform(ABC):
clean_shape: bool,
is_clean: bool,
has_valid_shapes: bool,
phase: str,
) -> TransformInfo:
"""Run graph cleanup before the transform.
"""Run graph cleanup before or after the transform.
Args:
mod: The model to run cleanup on.
@ -459,23 +686,28 @@ class BaseTransform(ABC):
clean_shape: Whether we want clean shapes after the transform.
is_clean: The current cleanup status.
has_valid_shapes: The current shape propagation status.
phase: The phase of cleanup ("pre" or "post").
Returns:
An info object indicating the cleanup status after this function is called.
"""
# check whether to run cleanup depending on the config and info
if clean_shape and not (is_clean and has_valid_shapes):
self._log_info("running graph cleanup (with shape_prop)")
self._log_cleanup_status(phase, "ran", "graph canonicalization + shape_prop")
canonicalize_graph(mod)
with lift_to_meta(mod) if placeholders_on_meta(mod) else nullcontext():
run_shape_prop(mod)
is_clean = True
has_valid_shapes = True
elif clean_graph and not is_clean:
self._log_info("running graph cleanup (no shape_prop)")
self._log_cleanup_status(phase, "ran", "graph canonicalization")
canonicalize_graph(mod)
is_clean = True
has_valid_shapes = False
elif not clean_graph and not clean_shape:
self._log_cleanup_status(phase, "skipped", "disabled")
else:
self._log_cleanup_status(phase, "skipped", "graph already clean")
return TransformInfo(is_clean=is_clean, has_valid_shapes=has_valid_shapes)

View File

@ -46,6 +46,9 @@ class BuildModel(BaseTransform):
# build the model
model = factory.build_model(self.config.device)
# update the kv cache config
cm.update_kv_cache_config(**factory.get_cache_config_updates())
# by convention, we say the model is always clean
info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)

View File

@ -21,13 +21,14 @@ import torch
from torch._ops import OpOverloadPacket
from torch.fx import GraphModule, Node
from .....llmapi.llm_args import KvCacheConfig
from ...custom_ops.attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
CacheConfig,
CacheInitializerDict,
MHACallable,
ResourceHandler,
ResourceHandlerDict,
SequenceInfo,
)
from ...models.factory import ModelFactory
@ -195,12 +196,30 @@ class DetectHiddenStatesForCapture(BaseTransform):
return gm, info
class HiddenStatesResourceHandler(ResourceHandler):
"""A resource handler for hidden states."""
def __init__(self, hidden_size: int, dtype: torch.dtype) -> None:
"""Initialize the HiddenStatesResourceHandler.
Args:
hidden_size: The size of the hidden states resource.
dtype: The dtype of the hidden states resource.
"""
self.hidden_size = hidden_size
self.dtype = dtype
def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
return torch.empty(
sequence_info.max_num_tokens,
self.hidden_size,
device=sequence_info.device,
dtype=self.dtype,
)
@AttentionRegistry.register("cached_residual_add")
class CachedResidualAdd(AttentionDescriptor):
@classmethod
def is_paged(cls) -> bool:
return True
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
return "bsnd"
@ -219,15 +238,12 @@ class CachedResidualAdd(AttentionDescriptor):
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
hidden_size = source_attn_node.meta["val"].shape[-1]
hidden_type = source_attn_node.meta["val"].dtype
def _get_hidden_states_cache(si: SequenceInfo):
return torch.empty(si.max_num_tokens, hidden_size, device=si.device, dtype=hidden_type)
return {"hidden_states_cache": _get_hidden_states_cache}
return {"hidden_states_cache": HiddenStatesResourceHandler(hidden_size, dtype=hidden_type)}
@classmethod
def get_standard_metadata_args(cls) -> List[str]:

View File

@ -18,7 +18,7 @@
import inspect
import operator
from typing import Dict, List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type
import torch
import torch.nn as nn
@ -28,17 +28,13 @@ from torch.fx import GraphModule, Node
from ...custom_ops.attention_interface import (
AttentionDescriptor,
AttentionRegistry,
CacheConfig,
Constant,
PrepareMetadataCallable,
)
from ...distributed.common import all_gather_object, get_world_size
from ...distributed.common import is_initialized as is_distributed_initialized
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import add_graph_input
from ...utils.cuda_mem_tracker import get_mem_info_in_mb
from ...utils.logger import ad_logger
from ...utils.cuda_mem_tracker import get_mem_info
from ...utils.node_utils import is_op
from ..interface import (
BaseTransform,
@ -53,9 +49,6 @@ class InsertCachedAttentionConfig(TransformConfig):
"""Configuration for the insert cached attention transform."""
backend: Optional[str] = Field(default=None, description="The attention backend to use.")
cache_config: CacheConfig = Field(
default_factory=CacheConfig, description="The custom cache configuration to use."
)
@TransformRegistry.register("insert_cached_attention")
@ -154,7 +147,6 @@ class InsertCachedAttention(BaseTransform):
meta_nodes_std: List[Node],
meta_nodes_extra: List[Node],
cache_nodes: List[Node],
buffer_nodes: List[Node],
constants: List[Constant],
):
"""Insert a cached attention node into the graph."""
@ -166,7 +158,6 @@ class InsertCachedAttention(BaseTransform):
*meta_nodes_std,
*meta_nodes_extra,
*cache_nodes,
*buffer_nodes,
*constants,
),
)
@ -183,10 +174,6 @@ class InsertCachedAttention(BaseTransform):
"""Replace uncached source attention node with corresponding cached attn node."""
attn_descriptor = self.attn_descriptor
# run field-wise or to combine the cache config from the transform and the factory
# the transform config takes precedence over the factory config
cache_config = self.config.cache_config | factory.get_cache_config()
# Get all attention nodes and their info objects
source_op = attn_descriptor.get_source_attention_op()
@ -199,10 +186,6 @@ class InsertCachedAttention(BaseTransform):
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
# Sanity check
if cm.info.is_paged:
assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op."
# get standard metadata nodes for all source attention nodes
meta_nodes_std = self._process_metadata_std(gm, cm)
@ -212,8 +195,6 @@ class InsertCachedAttention(BaseTransform):
# Register host-side prepare_metadata function for attention descriptor.
self._process_metadata_host(cm)
buffer_in_lookup: Dict[str, Node] = {}
# replace fused attention node with attention node that has kv cache
num_cached_attn_replacements = 0
for idx, attn_node in enumerate(source_attn_nodes):
@ -222,22 +203,13 @@ class InsertCachedAttention(BaseTransform):
# setup + store cache initializers and caches as input nodes
cache_in_nodes = []
for k, get_cache in attn_descriptor.get_cache_initializers(
attn_node, cache_config
for k, resource_handler in attn_descriptor.get_cache_initializers(
attn_node, cm.kv_cache_config
).items():
k_indexed = f"{k}_{idx}"
cm.add_cache(k_indexed, get_cache)
cm.add_resource(k_indexed, resource_handler)
cache_in_nodes.append(self._process_cache_node(gm, k_indexed))
# setup + store global buffer initializers and buffers as input nodes
# NOTE: we have to check against existing keys to make sure nothing is registered twice...
buffer_in_nodes = []
for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items():
if k not in buffer_in_lookup:
cm.add_cache(k, get_buffer)
buffer_in_lookup[k] = self._process_cache_node(gm, k)
buffer_in_nodes.append(buffer_in_lookup[k]) # store buffer nodes for this op
# retrieve constants for attention_op
constants = attn_descriptor.get_constants(attn_node)
@ -249,7 +221,6 @@ class InsertCachedAttention(BaseTransform):
meta_nodes_std,
meta_nodes_extra,
cache_in_nodes,
buffer_in_nodes,
constants,
)
@ -276,27 +247,15 @@ class InsertCachedMLAAttention(InsertCachedAttention):
pass
class ResizeKVCacheConfig(TransformConfig):
"""Configuration for the resize kv cache transform."""
free_mem_ratio: float = Field(
default=0.0, ge=0.0, le=1.0, description="The fraction of available memory to occupy."
)
@TransformRegistry.register("resize_kv_cache")
class ResizeKVCache(BaseTransform):
"""Inflate the kv cache to occupy the available GPU memory.
"""Resize the KV cache to occupy available GPU memory.
free_mem_ratio specifies the fraction of available memory to occupy.
This implements the two-phase approach:
1. Run a forward pass to allocate intermediate memory (activations, workspaces, etc.)
2. Call resize_kv_cache_manager() to recreate KVCacheManager with optimal capacity
"""
config: ResizeKVCacheConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return ResizeKVCacheConfig
def _apply_to_full_model(
self,
mod: nn.Module,
@ -304,100 +263,29 @@ class ResizeKVCache(BaseTransform):
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[nn.Module, TransformInfo]:
free_mem_ratio = self.config.free_mem_ratio
free_mem, total_mem = get_mem_info_in_mb(empty_cache=True)
self._log_info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
current_cache_size = cm.current_cache_size_bytes()
current_kv_cache_size = getattr(cm, "current_kv_cache_size_bytes", None)
current_kv_cache_size = (
current_kv_cache_size() if callable(current_kv_cache_size) else current_cache_size
)
current_num_pages = cm.info.num_pages
self._log_info(
f"Current cache size (MB): {current_cache_size // 1024**2}, "
f"Current num pages: {current_num_pages}"
)
if current_kv_cache_size != current_cache_size:
self._log_info(
f"Current KV-only cache size (MB): {current_kv_cache_size // 1024 // 1024}"
)
if free_mem_ratio == 0.0:
self._log_info(f"Skipping cache resize for {free_mem_ratio=}")
# check if we need a resize or not
if not cm.needs_resize():
return mod, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
# TODO: the manual PyTorch workflow respects max_num_tokens if set and does _NOT_ resize
# the cache in this case. Should we do the same here?
# Let's run a forward pass to get the memory usage
# Run a forward pass to get the extra memory usage
cm.info.set_max_num_tokens_sample()
free_mem_pre, _ = get_mem_info_in_mb(empty_cache=True)
self._log_info(f"Free memory before forward pass (MB): {free_mem_pre}")
# Reset peak memory stats to get the extra memory used during the forward pass
torch.cuda.reset_peak_memory_stats()
memory_allocated_before_forward_pass_mb = torch.cuda.memory_allocated() // 1024**2
try:
mod(**cm.named_args)
except torch.OutOfMemoryError as e:
ad_logger.error(
self._log_info(
f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n{e}"
)
raise e
peak_memory_during_forward_pass_mb = torch.cuda.max_memory_allocated() // 1024**2
mem_used_during_forward_pass_mb = (
peak_memory_during_forward_pass_mb - memory_allocated_before_forward_pass_mb
)
self._log_info(
f"Peak memory uasge during forward pass (MB): {peak_memory_during_forward_pass_mb}"
)
self._log_info(
f"Extra memory used during forward pass (MB): {mem_used_during_forward_pass_mb}"
)
# NOTE: use fragmented memory without empty cache (peak forward memory + fragmented memory)
# as a proxy for the memory reserved for the forward pass. This is a rough estimate and
# may not be accurate.
*_, mem_reserved_for_forward = get_mem_info(empty_cache=False, unit="B")
# TODO (lucaslie): logic needs overhaul, too much going on. For now, this is just reverting
# to the original logic. Full overhaul will be done as part of #10013
free_mem_post, _ = get_mem_info_in_mb(empty_cache=False)
self._log_info(f"Free memory after forward pass (MB): {free_mem_post}")
memory_for_forward_pass = free_mem_pre - free_mem_post
self._log_info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
# Compute new pages using KV-only bytes to avoid SSM/conv inflating per-page cost
# Reserve headroom to avoid OOM from other allocations (workspaces, cudagraph pools, etc.)
reserve_mb = max(1024, (total_mem * 5) // 100) # at least 1 GiB or 5% of total
available_mb = max(0, free_mem_post - reserve_mb)
new_kv_total_bytes = int(
available_mb * 1024 * 1024 * free_mem_ratio + current_kv_cache_size
)
per_page_bytes = max(1, current_kv_cache_size // max(1, current_num_pages))
new_num_pages = int(new_kv_total_bytes // per_page_bytes)
# Need to sync all the GPUs if distributed group is initialized
log_msg = f"Using local new_num_pages: {new_num_pages}"
if is_distributed_initialized():
gathered_num_pages = [None] * get_world_size()
all_gather_object(gathered_num_pages, new_num_pages)
new_num_pages = min(gathered_num_pages)
log_msg = f"After all_gather - new_num_pages: {new_num_pages}"
self._log_info(log_msg)
cm.resize_cache(new_num_pages)
# Log the final cache size for performance measurement, do not remove this log.
final_cache_size_bytes = cm.current_cache_size_bytes()
final_cache_size_gb = final_cache_size_bytes / (1024**3) # Convert to GiB
self._log_info(
f"Final KV cache size after resize: {final_cache_size_gb:.2f} GiB ({new_num_pages} pages)"
)
# Free memory
torch.cuda.empty_cache()
# Resize - KVCacheManager will compute optimal capacity based on free memory
cm.resize_kv_cache_manager(mem_reserved_for_forward)
info = TransformInfo(
skipped=False,
@ -411,6 +299,13 @@ class ResizeKVCache(BaseTransform):
@TransformRegistry.register("initialize_cache")
class InitializeCache(BaseTransform):
"""Initialize KV caches using KVCacheManager.
Gets kv_cache_config from shared_config.ad_config and creates the KVCacheManager
in estimation mode with conservative capacity. The ResizeKVCache transform will
later recreate it with optimal capacity after measuring memory usage.
"""
def _apply_to_full_model(
self,
mod: nn.Module,
@ -418,7 +313,9 @@ class InitializeCache(BaseTransform):
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[nn.Module, TransformInfo]:
num_caches = cm.initialize_caches()
# Initialize with estimation mode
# This allows resize_kv_cache to recreate with correct capacity after measuring memory
num_caches = cm.initialize_resources()
self._log_info(f"Initialized {num_caches} caches for cached attention")
info = TransformInfo(

View File

@ -163,8 +163,8 @@ def get_cached_attn(
query,
key,
value,
# metadata+caches+buffers with name lookup set up during kvcache transform
*[kwargs[k] for k in module._node_ref.meta["metadata_cache_buffer_keys"]],
# metadata+caches with name lookup set up during kvcache transform
*[kwargs[k] for k in module._node_ref.meta["metadata_cache_keys"]],
# constants set up during kvcache transform
*module._node_ref.meta["constants"],
)
@ -242,17 +242,11 @@ class HFReplaceCachedAttn(InsertCachedAttention):
meta_nodes_std: List[Node],
meta_nodes_extra: List[Node],
cache_nodes: List[Node],
buffer_nodes: List[Node],
constants: List[Constant],
):
"""Here we now need to actually do the correct mapping of the cached attn nodes."""
# store reference to metadata, caches, buffers, and constants for this attn node
attn_node.meta["metadata_cache_buffer_keys"] = (
*meta_nodes_std,
*meta_nodes_extra,
*cache_nodes,
*buffer_nodes,
)
# store reference to metadata, caches, and constants for this attn node
attn_node.meta["metadata_cache_keys"] = (*meta_nodes_std, *meta_nodes_extra, *cache_nodes)
attn_node.meta["constants"] = constants
def _apply_to_full_model(

View File

@ -8,6 +8,7 @@ from pydantic import Field
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import move_to_device
from ...utils.cuda_mem_tracker import bytes_to
from ..interface import (
BaseTransform,
SharedConfig,
@ -43,10 +44,11 @@ class LoadWeightsToDevice(BaseTransform):
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[nn.Module, TransformInfo]:
factory.load_or_random_init(
mod,
device=self.config.checkpoint_device or cm.device,
)
params_size = sum(p.numel() * p.element_size() for p in mod.parameters())
total_size_GB = bytes_to(params_size, unit="GB")
self._log_info(f"Estimated parameters memory: {total_size_GB:.2f} GB")
factory.load_or_random_init(mod, device=self.config.checkpoint_device or cm.device)
move_to_device(mod, cm.device)
info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True)

View File

@ -16,12 +16,15 @@
import gc
from contextlib import contextmanager
from typing import Tuple
from typing import Literal, Tuple, Union
import torch
from .logger import ad_logger
Number = Union[int, float]
ByteUnit = Literal["B", "KB", "MB", "GB", "TB"]
@contextmanager
def cuda_memory_tracker(logger=ad_logger):
@ -43,10 +46,38 @@ def cuda_memory_tracker(logger=ad_logger):
logger.warning(f"Potential memory leak detected, leaked memory: {leaked} bytes")
def get_mem_info_in_mb(empty_cache: bool = True) -> Tuple[int, int]:
def bytes_to(bytes: int, *more_bytes: int, unit: ByteUnit) -> Union[Number, Tuple[Number, ...]]:
units = {"KB": 1 << 10, "MB": 1 << 20, "GB": 1 << 30, "TB": 1 << 40}
bytes_converted = (bytes,) + more_bytes
unit = unit.upper()
if unit != "B":
bytes_converted = tuple(float(x) / units[unit.upper()] for x in bytes_converted)
return bytes_converted if more_bytes else bytes_converted[0]
def get_mem_info(
empty_cache: bool = True, unit: ByteUnit = "B"
) -> Tuple[Number, Number, Number, Number, Number]:
"""Get the memory information of the current device.
Args:
empty_cache: Whether to empty the memory cache.
unit: The unit of the memory information. Defaults to bytes.
Returns:
A tuple of the
- total memory,
- free memory,
- reserved memory,
- allocated memory,
- fragmented memory
in the specified unit.
"""
if empty_cache:
# Clear the memory cache to get the exact free memory
torch.cuda.empty_cache()
free_mem, total_mem = torch.cuda.mem_get_info()
MB = 1024**2
return free_mem // MB, total_mem // MB
res_mem = torch.cuda.memory_reserved()
alloc_mem = torch.cuda.memory_allocated()
frag_mem = res_mem - alloc_mem
return bytes_to(total_mem, free_mem, res_mem, alloc_mem, frag_mem, unit=unit)
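# Usage sketch (illustrative only; requires a CUDA device and the helpers above):
def _example_mem_report() -> None:
    # single value: 1 GiB expressed in MB -> 1024.0
    one_gib_in_mb = bytes_to(1 << 30, unit="MB")
    # multiple values: a tuple comes back in the same order -> (4.0, 1024.0)
    small_kb, large_kb = bytes_to(4096, 1 << 20, unit="KB")
    # device memory in MB, order per the docstring: total, free, reserved, allocated, fragmented
    total, free, reserved, allocated, fragmented = get_mem_info(empty_cache=False, unit="MB")
    ad_logger.info(
        f"one GiB={one_gib_in_mb} MB, 4096 B={small_kb} KB, "
        f"free={free:.0f}/{total:.0f} MB, fragmented={fragmented:.0f} MB"
    )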

View File

@ -193,6 +193,7 @@ class MambaCacheManager(BaseResourceManager):
dtype=torch.int32,
device=device)
@torch.inference_mode()
def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
self.state_indices_list.clear()
for r in request_ids:

View File

@ -40,10 +40,10 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
# Set it explicitly here to 8192 which is the default in build_config.
"max_num_tokens": 8192,
"skip_loading_weights": False,
"kv_cache_config": {
"free_gpu_memory_fraction": 0.7
},
"transforms": {
"resize_kv_cache": {
"free_mem_ratio": 0.7
},
"compile_model": {
"backend":
"torch-cudagraph",
@ -55,7 +55,7 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
if enable_chunked_prefill:
config["enable_chunked_prefill"] = True
config[
"max_num_tokens"] = 512 # NOTE: must be > max(attn_page_size, max_batch_size)
"max_num_tokens"] = 512 # NOTE: must be > max(tokens_per_block, max_batch_size)
return config
def get_default_sampling_params(self):
@ -110,7 +110,8 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
"trust_remote_code": True,
# SSMs do not support cache reuse.
"kv_cache_config": {
"enable_block_reuse": False
"enable_block_reuse": False,
"free_gpu_memory_fraction": 0.7
},
# Keep max_batch_size as in the PyTorch test to avoid OOM
"max_batch_size": 128,
@ -120,9 +121,6 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
"max_num_tokens": 8192,
"skip_loading_weights": False,
"transforms": {
"resize_kv_cache": {
"free_mem_ratio": 0.7
},
"compile_model": {
"backend": "torch-cudagraph",
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
@ -132,7 +130,7 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
if enable_chunked_prefill:
config["enable_chunked_prefill"] = True
config[
"max_num_tokens"] = 512 # NOTE: must be > max(attn_page_size, max_batch_size)
"max_num_tokens"] = 512 # NOTE: must be > max(tokens_per_block, max_batch_size)
return config
def get_default_sampling_params(self):
@ -169,7 +167,10 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
"trust_remote_code": True,
# SSMs do not support cache reuse.
"kv_cache_config": {
"enable_block_reuse": False
"enable_block_reuse": False,
"free_gpu_memory_fraction": 0.7
# NOTE: some accuracy benchmarks may require fp32 precision for mamba cache
# "mamba_ssm_cache_dtype": "float32",
},
# Keep max_batch_size as in the PyTorch test to avoid OOM
"max_batch_size": 128,
@ -180,7 +181,6 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
"max_num_tokens": 8192,
"skip_loading_weights": False,
"compile_backend": "torch-cudagraph",
"free_mem_ratio": 0.7,
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
"transforms": {
"detect_sharding": {
@ -191,12 +191,6 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
"stage": "compile",
"enabled": True,
},
# NOTE: some accuracy benchmarks may require fp32 precision for mamba cache
# "insert_cached_ssm_attention": {
# "cache_config": {
# "mamba_dtype": "float32",
# },
# },
}
}
@ -213,7 +207,8 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
kwargs = self.get_default_kwargs()
# TODO: multi-stream MOE seems to increase the memory usage
kwargs["max_batch_size"] = 32
kwargs["free_mem_ratio"] = 0.4
kwargs["kv_cache_config"] = {"free_gpu_memory_fraction": 0.4}
sampling_params = self.get_default_sampling_params()
with AutoDeployLLM(model=self.MODEL_PATH_BF16,
tokenizer=self.MODEL_PATH_BF16,
**kwargs) as llm:
@ -279,7 +274,6 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
"trust_remote_code": True,
"skip_loading_weights": False,
"compile_backend": "torch-cudagraph",
"free_mem_ratio": 0.9,
"max_batch_size": 128,
"max_seq_len": self.MAX_SEQ_LEN,
"max_num_tokens": self.MAX_SEQ_LEN,

View File

@ -576,7 +576,6 @@ class PerfTestConfig:
extra: bool = False,
# _autodeploy backend specific parameters
ad_compile_backend: str = "torch-opt",
free_mem_ratio: float = 0.9,
extra_runtime: str = "trtllm",
skip_loading_weights: bool = False,
):
@ -636,7 +635,6 @@ class PerfTestConfig:
self.extra = extra
# _autodeploy backend specific parameters
self.ad_compile_backend = ad_compile_backend
self.free_mem_ratio = free_mem_ratio
self.extra_runtime = extra_runtime
self.skip_loading_weights = skip_loading_weights
# Just build engines
@ -1422,9 +1420,6 @@ class MultiMetricPerfTest(AbstractPerfScriptTestClass):
'compile_model': {
'backend': self._config.ad_compile_backend
},
'resize_kv_cache': {
'free_mem_ratio': self._config.free_mem_ratio
},
},
'runtime': self._config.extra_runtime,
'skip_loading_weights': self._config.skip_loading_weights

View File

@ -18,11 +18,11 @@ from tensorrt_llm._torch.auto_deploy.transform.library.sharding import ShardingT
class FakeFactory(ModelFactory):
"""Dummy factory to pass cache_config for testing."""
"""Dummy factory to pass cache_config_updates for testing."""
def __init__(self, model=None, cache_config=None, quant_config=None):
def __init__(self, model=None, cache_config_updates=None, quant_config=None):
self._model = model
self.cache_config = cache_config
self.cache_config_updates = cache_config_updates
self.quant_config = quant_config
def build_model(self, device: str):
@ -34,8 +34,8 @@ class FakeFactory(ModelFactory):
def _load_checkpoint(self, model, device):
return
def get_cache_config(self):
return self.cache_config
def get_cache_config_updates(self):
return self.cache_config_updates
def get_quant_config(self):
return self.quant_config

View File

@ -511,7 +511,10 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, An
# add some defaults to llm_args
llm_args["skip_loading_weights"] = True # No weight loading to speed up things
llm_args["attn_page_size"] = 4 # Make sure paging is activated despite small max_tokens
llm_args["kv_cache_config"] = {
"tokens_per_block": 4, # Make sure paging is activated despite small max_tokens
"free_gpu_memory_fraction": 0.0, # No resizing of the cache to keep the mem footprint small
}
llm_args["max_batch_size"] = 2 # Minimum batching to speed up things
# update with custom llm_args kwargs
llm_args.update(llm_args_kwargs)

View File

@ -84,8 +84,7 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
)
# make sure planner is initialized
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)
_GlobalFlashInferPlanner.reset(torch.device(device))
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
@ -118,8 +117,6 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
# CACHES
k_cache,
v_cache,
# BUFFERS
workspace,
# CONSTANTS
None,
1.0,
@ -229,8 +226,7 @@ def test_flashinfer_attention_op_decode(
)
# make sure planner is initialized
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)
_GlobalFlashInferPlanner.reset(torch.device(device))
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
@ -262,8 +258,6 @@ def test_flashinfer_attention_op_decode(
# CACHES
k_cache,
v_cache,
# BUFFERS
workspace,
# CONSTANTS
None,
1.0,
@ -361,8 +355,7 @@ def test_flashinfer_attention_context_and_generate(
)
# make sure planner is initialized
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)
_GlobalFlashInferPlanner.reset(torch.device(device))
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
@ -395,8 +388,6 @@ def test_flashinfer_attention_context_and_generate(
# CACHES
k_cache,
v_cache,
# BUFFERS
workspace,
# CONSTANTS
None,
1.0,
@ -454,7 +445,7 @@ def test_flashinfer_attention_context_and_generate(
v_3 = torch.randn(BATCH_SIZE, 1, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
# Create FlashInferAttention class before calling the custom op
_GlobalFlashInferPlanner.reset()
_GlobalFlashInferPlanner.reset(torch.device(device))
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
@ -485,8 +476,6 @@ def test_flashinfer_attention_context_and_generate(
# CACHES
k_cache,
v_cache,
# BUFFERS
workspace,
# CONSTANTS
None,
1.0,
@ -586,8 +575,7 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
)
# make sure planner is initialized
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)
_GlobalFlashInferPlanner.reset(torch.device(device))
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
@ -620,8 +608,6 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
# CACHES
k_cache,
v_cache,
# BUFFERS
workspace,
# CONSTANTS
None,
1.0,
@ -748,8 +734,7 @@ def test_flashinfer_attention_with_fp8_cache(
v_cache = v_cache.to(torch.float8_e4m3fn)
# make sure planner is initialized
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)
_GlobalFlashInferPlanner.reset(torch.device(device))
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
@ -782,8 +767,6 @@ def test_flashinfer_attention_with_fp8_cache(
# CACHES
k_cache,
v_cache,
# BUFFERS
workspace,
# CONSTANTS
None,
K_SCALE,
@ -859,8 +842,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
# make sure planner is initialized
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)
_GlobalFlashInferPlanner.reset(torch.device(device))
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
@ -891,8 +873,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
# CACHES
k_cache,
v_cache,
# BUFFERS
workspace,
# CONSTANTS
None,
1.0,
@ -956,7 +936,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
seq_len_with_cache2_host = (offsets2 + seq_len_tensor2).cpu()
# Create FlashInferAttention class before calling the custom op
_GlobalFlashInferPlanner.reset()
_GlobalFlashInferPlanner.reset(torch.device(device))
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr2,
@ -987,8 +967,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
# CACHES
k_cache,
v_cache,
# BUFFERS
workspace,
# CONSTANTS
None,
1.0,

View File

@ -0,0 +1,237 @@
"""Unit tests for ResourceHandler classes in attention_interface.py.
Tests the new resource handler abstraction for cache management:
- PagedResourceHandler (for paged KV caches)
- StateResourceHandler (for SSM/conv states)
- UnpagedResourceHandler (for unpaged local caches)
- AttentionDescriptor.resolve_cache_dtype()
"""
import pytest
import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import (
AttentionDescriptor,
ManagedResourceHandler,
PagedResourceHandler,
ResourceHandler,
SequenceInfo,
StateResourceHandler,
UnpagedResourceHandler,
)
# =============================================================================
# PagedResourceHandler Tests
# =============================================================================
def test_paged_handler_stores_token_shape_and_dtype():
"""Verify PagedResourceHandler stores token_shape and dtype correctly."""
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
assert handler.token_shape == (8, 64)
assert handler.dtype == torch.float16
def test_paged_handler_single_dimension_token_shape():
"""Test PagedResourceHandler with single dimension token shape."""
handler = PagedResourceHandler(128, dtype=torch.bfloat16)
assert handler.token_shape == (128,)
assert handler.dtype == torch.bfloat16
def test_paged_handler_multi_dimension_token_shape():
"""Test PagedResourceHandler with multiple dimension token shape."""
handler = PagedResourceHandler(4, 8, 16, dtype=torch.float32)
assert handler.token_shape == (4, 8, 16)
assert handler.dtype == torch.float32
def test_paged_handler_allocate_raises_not_implemented():
"""Verify PagedResourceHandler.allocate() raises NotImplementedError."""
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4)
with pytest.raises(NotImplementedError, match="Managed resources should not be allocated"):
handler.allocate(seq_info)
def test_paged_handler_is_resource_handler():
"""Verify PagedResourceHandler is a ResourceHandler subclass."""
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
assert isinstance(handler, ResourceHandler)
def test_paged_handler_is_managed_resource():
"""Verify PagedResourceHandler is a ManagedResourceHandler."""
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
assert isinstance(handler, ManagedResourceHandler)
# =============================================================================
# StateResourceHandler Tests
# =============================================================================
def test_state_handler_stores_state_shape_and_dtype():
"""Verify StateResourceHandler stores state_shape and dtype correctly."""
handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
assert handler.state_shape == (4, 64, 16)
assert handler.dtype == torch.bfloat16
def test_state_handler_single_dimension_state_shape():
"""Test StateResourceHandler with single dimension state shape."""
handler = StateResourceHandler(256, dtype=torch.float16)
assert handler.state_shape == (256,)
assert handler.dtype == torch.float16
def test_state_handler_conv_state_shape():
"""Test StateResourceHandler with typical conv state shape [in_channels, kernel_size-1]."""
handler = StateResourceHandler(512, 3, dtype=torch.bfloat16)
assert handler.state_shape == (512, 3)
assert handler.dtype == torch.bfloat16
def test_state_handler_ssm_state_shape():
"""Test StateResourceHandler with typical SSM state shape [num_heads, head_dim, ssm_state_size]."""
handler = StateResourceHandler(4, 64, 16, dtype=torch.float32)
assert handler.state_shape == (4, 64, 16)
assert handler.dtype == torch.float32
def test_state_handler_allocate_raises_not_implemented():
"""Verify StateResourceHandler.allocate() raises NotImplementedError."""
handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4)
with pytest.raises(NotImplementedError, match="Managed resources should not be allocated"):
handler.allocate(seq_info)
def test_state_handler_is_resource_handler():
"""Verify StateResourceHandler is a ResourceHandler subclass."""
handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
assert isinstance(handler, ResourceHandler)
def test_state_handler_is_managed_resource():
"""Verify StateResourceHandler is a ManagedResourceHandler."""
handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
assert isinstance(handler, ManagedResourceHandler)
# =============================================================================
# UnpagedResourceHandler Tests
# =============================================================================
def test_unpaged_handler_stores_token_shape_and_dtype():
"""Verify UnpagedResourceHandler stores token_shape and dtype correctly."""
handler = UnpagedResourceHandler(8, 64, dtype=torch.float16)
assert handler.token_shape == (8, 64)
assert handler.dtype == torch.float16
def test_unpaged_handler_single_dimension_token_shape():
"""Test UnpagedResourceHandler with single dimension token shape."""
handler = UnpagedResourceHandler(128, dtype=torch.bfloat16)
assert handler.token_shape == (128,)
assert handler.dtype == torch.bfloat16
@pytest.mark.parametrize(
"num_kv_heads,head_dim,dtype",
[
(8, 64, torch.float16),
(4, 128, torch.bfloat16),
(1, 64, torch.float32),
],
)
def test_unpaged_handler_allocate_returns_correct_shape(num_kv_heads, head_dim, dtype):
"""Verify UnpagedResourceHandler.allocate() returns tensor with correct shape."""
max_batch_size = 4
max_seq_len = 128
handler = UnpagedResourceHandler(num_kv_heads, head_dim, dtype=dtype)
seq_info = SequenceInfo(max_seq_len=max_seq_len, max_batch_size=max_batch_size)
seq_info.to("cuda")
tensor = handler.allocate(seq_info)
expected_shape = (seq_info.max_num_state_slots, max_seq_len, num_kv_heads, head_dim)
assert tensor.shape == expected_shape
assert tensor.dtype == dtype
assert tensor.device.type == "cuda"
def test_unpaged_handler_allocate_correct_device():
"""Verify UnpagedResourceHandler allocated tensor is on the correct device."""
handler = UnpagedResourceHandler(8, 64, dtype=torch.float16)
seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4)
seq_info.to("cuda")
tensor = handler.allocate(seq_info)
assert tensor.device == seq_info.device
def test_unpaged_handler_is_resource_handler():
"""Verify UnpagedResourceHandler is a ResourceHandler subclass."""
handler = UnpagedResourceHandler(8, 64, dtype=torch.float16)
assert isinstance(handler, ResourceHandler)
def test_unpaged_handler_is_not_managed_resource():
"""Verify UnpagedResourceHandler is NOT a ManagedResourceHandler."""
handler = UnpagedResourceHandler(8, 64, dtype=torch.float16)
assert not isinstance(handler, ManagedResourceHandler)
# =============================================================================
# AttentionDescriptor.resolve_cache_dtype() Tests
# =============================================================================
def test_resolve_cache_dtype_auto_returns_fallback_float16():
"""Test 'auto' returns the fallback dtype (float16)."""
result = AttentionDescriptor.resolve_cache_dtype("auto", torch.float16)
assert result == torch.float16
def test_resolve_cache_dtype_auto_returns_fallback_bfloat16():
"""Test 'auto' returns the fallback dtype (bfloat16)."""
result = AttentionDescriptor.resolve_cache_dtype("auto", torch.bfloat16)
assert result == torch.bfloat16
def test_resolve_cache_dtype_auto_returns_fallback_float32():
"""Test 'auto' returns the fallback dtype (float32)."""
result = AttentionDescriptor.resolve_cache_dtype("auto", torch.float32)
assert result == torch.float32
def test_resolve_cache_dtype_explicit_float16():
"""Test explicit 'float16' dtype string resolves correctly."""
result = AttentionDescriptor.resolve_cache_dtype("float16", torch.bfloat16)
assert result == torch.float16
def test_resolve_cache_dtype_explicit_bfloat16():
"""Test explicit 'bfloat16' dtype string resolves correctly."""
result = AttentionDescriptor.resolve_cache_dtype("bfloat16", torch.float16)
assert result == torch.bfloat16
def test_resolve_cache_dtype_explicit_float32():
"""Test explicit 'float32' dtype string resolves correctly."""
result = AttentionDescriptor.resolve_cache_dtype("float32", torch.float16)
assert result == torch.float32
@pytest.mark.skipif(
torch.cuda.get_device_capability(0) < (8, 9), reason="FP8 requires compute capability >= 8.9"
)
def test_resolve_cache_dtype_explicit_fp8():
"""Test explicit 'fp8' dtype string resolves correctly."""
result = AttentionDescriptor.resolve_cache_dtype("fp8", torch.float16)
assert result == torch.float8_e4m3fn

View File

@ -329,7 +329,7 @@ def test_trtllm_fused_moe_fp8(
pytest.skip(f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})")
assert itype == torch.float8_e4m3fn and wtype == torch.float8_e4m3fn, (
"FP8 test only supports float8_e4m3fn"
"FP8 test only supports torch.float8_e4m3fn"
)
assert otype == torch.bfloat16 or otype == torch.float16, (
"FP8 test only supports bfloat16 or float16 output type"

View File

@ -0,0 +1,724 @@
"""Unit tests for CachedSequenceInterface in interface.py.
Tests the refactored CachedSequenceInterface which now:
- Creates SequenceInfo internally with tokens_per_block from KvCacheConfig
- Manages resources via KVCacheManager or MambaHybridCacheManager
- Supports paged resources (KV caches) and state resources (SSM states)
"""
import pytest
import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import (
PagedResourceHandler,
SequenceInfo,
StateResourceHandler,
UnpagedResourceHandler,
)
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import MambaHybridCacheManager
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def default_kv_cache_config():
"""KvCacheConfig with default settings."""
return KvCacheConfig()
@pytest.fixture
def paged_kv_cache_config():
"""KvCacheConfig with paging enabled and no resizing."""
return KvCacheConfig(
tokens_per_block=32,
max_tokens=1024,
free_gpu_memory_fraction=0.0, # Disable dynamic resizing
)
@pytest.fixture
def resizable_kv_cache_config():
"""KvCacheConfig with dynamic resizing enabled."""
return KvCacheConfig(
tokens_per_block=32,
max_tokens=1024,
free_gpu_memory_fraction=0.5,
)
# =============================================================================
# Initialization Tests
# =============================================================================
def test_init_creates_sequence_info_with_tokens_per_block(paged_kv_cache_config):
"""Verify SequenceInfo is created with correct tokens_per_block from KvCacheConfig."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
assert interface.info.tokens_per_block == paged_kv_cache_config.tokens_per_block
assert interface.info.max_seq_len == 128
assert interface.info.max_batch_size == 4
def test_init_uses_default_kv_cache_config_when_not_provided():
"""Verify default KvCacheConfig is used when not provided."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
)
# Default KvCacheConfig should be created
assert interface.kv_cache_config is not None
# Default tokens_per_block is 64 in KvCacheConfig
assert interface.info.tokens_per_block == interface.kv_cache_config.tokens_per_block
def test_init_propagates_max_num_tokens():
"""Verify max_num_tokens propagates to SequenceInfo."""
max_num_tokens = 512
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
max_num_tokens=max_num_tokens,
device="cuda",
)
assert interface.info.max_num_tokens == max_num_tokens
def test_init_propagates_vocab_size_padded():
"""Verify vocab_size_padded propagates to SequenceInfo."""
vocab_size_padded = 32000
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
vocab_size_padded=vocab_size_padded,
device="cuda",
)
assert interface.info.vocab_size_padded == vocab_size_padded
def test_init_stores_device():
"""Verify device is stored correctly."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda:0",
)
assert interface.device == "cuda:0"
def test_init_default_device_is_cuda():
"""Verify default device is 'cuda' when not specified."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
)
assert interface.device == "cuda"
# =============================================================================
# Resource Registration Tests
# =============================================================================
def test_add_resource_paged_handler(paged_kv_cache_config):
"""Test adding a PagedResourceHandler resource."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
interface.add_resource("k_cache_0", handler)
assert "k_cache_0" in interface._resource_lookup
assert interface._resource_lookup["k_cache_0"] is handler
def test_add_resource_state_handler(paged_kv_cache_config):
"""Test adding a StateResourceHandler resource."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
interface.add_resource("ssm_state_0", handler)
assert interface._resource_lookup["ssm_state_0"] is handler
def test_add_resource_unpaged_handler(paged_kv_cache_config):
"""Test adding an UnpagedResourceHandler resource."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
handler = UnpagedResourceHandler(8, 64, dtype=torch.float16)
interface.add_resource("unpaged_cache", handler)
assert "unpaged_cache" in interface._resource_lookup
assert interface._resource_lookup["unpaged_cache"] is handler
def test_add_multiple_resources(paged_kv_cache_config):
"""Test adding multiple resources of different types."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
k_handler = PagedResourceHandler(8, 64, dtype=torch.float16)
v_handler = PagedResourceHandler(8, 64, dtype=torch.float16)
ssm_handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
interface.add_resource("k_cache_0", k_handler)
interface.add_resource("v_cache_0", v_handler)
interface.add_resource("ssm_state_0", ssm_handler)
assert len(interface._resource_lookup) == 3
# =============================================================================
# Resource Initialization Tests
# =============================================================================
def test_initialize_resources_paged_only_creates_kv_cache_manager(paged_kv_cache_config):
"""Test paged-only resources create KVCacheManager (not MambaHybridCacheManager)."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
# Add only paged resources
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
num_caches = interface.initialize_resources()
assert num_caches == 2
assert isinstance(interface.kv_cache_manager, KVCacheManager)
assert not isinstance(interface.kv_cache_manager, MambaHybridCacheManager)
def test_initialize_resources_mixed_creates_mamba_hybrid_cache_manager(paged_kv_cache_config):
"""Test mixed paged + state resources create MambaHybridCacheManager."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
# Add paged and state resources
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.add_resource("ssm_state_0", StateResourceHandler(4, 64, 16, dtype=torch.bfloat16))
num_caches = interface.initialize_resources()
assert num_caches == 3
assert isinstance(interface.kv_cache_manager, MambaHybridCacheManager)
def test_initialize_resources_creates_cache_views_with_correct_shape(paged_kv_cache_config):
"""Verify cache views are created with correct shapes."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
num_kv_heads = 8
head_dim = 64
interface.add_resource(
"k_cache_0", PagedResourceHandler(num_kv_heads, head_dim, dtype=torch.float16)
)
interface.add_resource(
"v_cache_0", PagedResourceHandler(num_kv_heads, head_dim, dtype=torch.float16)
)
interface.initialize_resources()
# Check cache views exist
assert "k_cache_0" in interface._caches
assert "v_cache_0" in interface._caches
# Check shapes: [num_blocks, tokens_per_block, num_kv_heads, head_dim]
k_cache = interface._caches["k_cache_0"]
assert k_cache is not None
assert k_cache.shape[1] == paged_kv_cache_config.tokens_per_block
assert k_cache.shape[2] == num_kv_heads
assert k_cache.shape[3] == head_dim
assert k_cache.dtype == torch.float16
def test_initialize_resources_creates_state_views_with_correct_shape(paged_kv_cache_config):
"""Verify state views are created with correct shapes for MambaHybridCacheManager."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
num_heads = 4
head_dim = 64
ssm_state_size = 16
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.add_resource(
"ssm_state_0",
StateResourceHandler(num_heads, head_dim, ssm_state_size, dtype=torch.bfloat16),
)
interface.initialize_resources()
# Check state view exists
ssm_cache = interface._caches["ssm_state_0"]
assert ssm_cache is not None
# Shape: [num_states, num_heads, head_dim, ssm_state_size]
assert ssm_cache.shape[1] == num_heads
assert ssm_cache.shape[2] == head_dim
assert ssm_cache.shape[3] == ssm_state_size
assert ssm_cache.dtype == torch.bfloat16
def test_initialize_resources_unpaged_allocated_locally(paged_kv_cache_config):
"""Verify UnpagedResourceHandler resources are allocated locally (not via cache manager)."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
num_kv_heads = 8
head_dim = 64
interface.add_resource(
"unpaged_cache", UnpagedResourceHandler(num_kv_heads, head_dim, dtype=torch.float16)
)
interface.initialize_resources()
# Check unpaged cache was allocated
assert "unpaged_cache" in interface._caches
unpaged_cache = interface._caches["unpaged_cache"]
assert unpaged_cache is not None
# Shape: [max_batch_size + 1, max_seq_len, num_kv_heads, head_dim]
assert unpaged_cache.shape == (4 + 1, 128, num_kv_heads, head_dim)
def test_is_paged_returns_true_for_paged_only(paged_kv_cache_config):
"""Test is_paged() returns True when all resources are paged."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.initialize_resources()
assert interface.is_paged() is True
def test_is_paged_returns_false_for_hybrid(paged_kv_cache_config):
"""Test is_paged() returns False when state resources exist."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.add_resource("ssm_state_0", StateResourceHandler(4, 64, 16, dtype=torch.bfloat16))
interface.initialize_resources()
assert interface.is_paged() is False
# =============================================================================
# KV Cache Resize Tests
# =============================================================================
def test_needs_resize_returns_false_when_fraction_is_zero(paged_kv_cache_config):
"""Test needs_resize() returns False when free_gpu_memory_fraction is 0.0."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.initialize_resources()
assert interface.needs_resize() is False
def test_needs_resize_returns_true_when_fraction_is_positive(resizable_kv_cache_config):
"""Test needs_resize() returns True when free_gpu_memory_fraction is positive."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=resizable_kv_cache_config,
)
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.initialize_resources()
assert interface.needs_resize() is True
def test_resize_kv_cache_manager_skipped_when_not_needed(paged_kv_cache_config):
"""Test resize_kv_cache_manager() does nothing when resize not needed."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.initialize_resources()
# Get initial state
initial_manager = interface.kv_cache_manager
# Resize should be a no-op
interface.resize_kv_cache_manager()
# Manager should be the same object (no recreation)
assert interface.kv_cache_manager is initial_manager
# =============================================================================
# Shutdown and Cleanup Tests
# =============================================================================
def test_shutdown_clears_caches(paged_kv_cache_config):
"""Test shutdown() clears all caches."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.initialize_resources()
assert len(interface._caches) == 2
interface.shutdown()
assert len(interface._caches) == 0
def test_clear_cache_views_sets_views_to_none(paged_kv_cache_config):
"""Test _clear_cache_views() sets paged and state cache views to None."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.add_resource("ssm_state_0", StateResourceHandler(4, 64, 16, dtype=torch.bfloat16))
interface.initialize_resources()
# Manually call _clear_cache_views
interface._clear_cache_views()
# Paged and state caches should be None
assert interface._caches["k_cache_0"] is None
assert interface._caches["ssm_state_0"] is None
# =============================================================================
# Configuration Update Tests
# =============================================================================
def test_update_kv_cache_config_valid_field():
"""Test update_kv_cache_config() with valid field updates."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
)
interface.update_kv_cache_config(tokens_per_block=64)
assert interface.kv_cache_config.tokens_per_block == 64
def test_update_kv_cache_config_multiple_fields():
"""Test update_kv_cache_config() with multiple field updates."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
)
interface.update_kv_cache_config(
tokens_per_block=64,
max_tokens=2048,
free_gpu_memory_fraction=0.8,
)
assert interface.kv_cache_config.tokens_per_block == 64
assert interface.kv_cache_config.max_tokens == 2048
assert interface.kv_cache_config.free_gpu_memory_fraction == 0.8
def test_update_kv_cache_config_invalid_field_raises():
"""Test update_kv_cache_config() raises ValueError for invalid fields."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
)
with pytest.raises(ValueError, match="Invalid KVCacheConfig field"):
interface.update_kv_cache_config(invalid_field=123)
# =============================================================================
# named_args and args Tests
# =============================================================================
def test_named_args_includes_sequence_info_and_caches(paged_kv_cache_config):
"""Verify named_args includes both SequenceInfo args and caches."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.initialize_resources()
named_args = interface.named_args
# Should contain sequence info args
assert "input_ids" in named_args
assert "position_ids" in named_args
# Should contain cache
assert "k_cache_0" in named_args
def test_args_returns_tuple_of_tensors(paged_kv_cache_config):
"""Verify args returns a tuple of tensors."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=paged_kv_cache_config,
)
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
interface.initialize_resources()
args = interface.args
assert isinstance(args, tuple)
assert all(isinstance(a, torch.Tensor) for a in args)
# =============================================================================
# to() method Tests
# =============================================================================
def test_to_moves_sequence_info(paged_kv_cache_config):
"""Verify to() moves SequenceInfo to the target device."""
interface = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cpu",
kv_cache_config=paged_kv_cache_config,
)
interface.to("cuda")
assert interface.info.device.type == "cuda"
# =============================================================================
# SequenceInfo API Tests
# =============================================================================
def test_sequence_info_tokens_per_block_from_constructor():
"""Verify tokens_per_block is set correctly from constructor."""
seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, tokens_per_block=32)
assert seq_info.tokens_per_block == 32
def test_sequence_info_tokens_per_block_defaults_to_max_seq_len():
"""Verify tokens_per_block defaults to max_seq_len when not provided."""
seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4)
assert seq_info.tokens_per_block == 128
def test_sequence_info_estimate_cache_tokens_per_forward():
"""Test estimate_cache_tokens_per_forward() calculation."""
# With max_num_tokens=64, max_batch_size=4, tokens_per_block=16
# seq_len = ceil(64/4) = 16
# num_blocks_per_seq = ceil(16/16) = 1
# num_blocks_total = 1 * 4 = 4
# result = 4 * 16 = 64
seq_info = SequenceInfo(
max_seq_len=128,
max_batch_size=4,
tokens_per_block=16,
max_num_tokens=64,
)
result = seq_info.estimate_cache_tokens_per_forward()
# Expected: ceil(64/4) = 16 tokens per seq
# ceil(16/16) = 1 block per seq
# 1 * 4 = 4 blocks total
# 4 * 16 = 64 tokens
assert result == 64
def test_sequence_info_estimate_cache_tokens_per_forward_with_overflow():
"""Test estimate_cache_tokens_per_forward() with sequence overflow into extra blocks."""
# With max_num_tokens=100, max_batch_size=4, tokens_per_block=16
# seq_len = ceil(100/4) = 25
# num_blocks_per_seq = ceil(25/16) = 2
# num_blocks_total = 2 * 4 = 8
# result = 8 * 16 = 128
seq_info = SequenceInfo(
max_seq_len=128,
max_batch_size=4,
tokens_per_block=16,
max_num_tokens=100,
)
result = seq_info.estimate_cache_tokens_per_forward()
assert result == 128
def test_sequence_info_estimate_cache_loc_capacity_no_resize():
"""Test estimate_cache_loc_capacity() when capacity is sufficient."""
seq_info = SequenceInfo(
max_seq_len=128,
max_batch_size=4,
tokens_per_block=32,
max_num_tokens=256,
)
initial_capacity = seq_info._input_buffer.get_capacity("cache_loc")
# Request a small capacity that should already be available
seq_info.estimate_cache_loc_capacity(num_blocks=4)
# Capacity should not have changed if it was already sufficient
if initial_capacity >= 4 * 4 + 1: # num_blocks * max_batch_size + 1
assert seq_info._input_buffer.get_capacity("cache_loc") == initial_capacity
def test_sequence_info_estimate_cache_loc_capacity_resizes():
"""Test estimate_cache_loc_capacity() resizes buffer when needed."""
seq_info = SequenceInfo(
max_seq_len=128,
max_batch_size=4,
tokens_per_block=32,
max_num_tokens=128, # Small to have small initial capacity
)
initial_capacity = seq_info._input_buffer.get_capacity("cache_loc")
# Request a large capacity
large_num_blocks = 1000
seq_info.estimate_cache_loc_capacity(num_blocks=large_num_blocks)
expected_capacity = large_num_blocks * 4 + 1 # num_blocks * max_batch_size + 1
if expected_capacity > initial_capacity:
assert seq_info._input_buffer.get_capacity("cache_loc") >= expected_capacity
def test_sequence_info_last_page_len_uses_tokens_per_block():
"""Verify nest_sequences calculates last_page_len using tokens_per_block."""
seq_info = SequenceInfo(
max_seq_len=128,
max_batch_size=4,
tokens_per_block=16,
)
# Set up a sequence with 25 tokens
# last_page_len = (25 - 1) % 16 + 1 = 24 % 16 + 1 = 8 + 1 = 9
input_ids = [[1] * 25]
seq_info.nest_sequences(
input_ids,
input_pos=0,
cache_loc=[0, 1], # 2 pages
pages_per_seq=[2],
)
expected_last_page_len = (25 - 1) % 16 + 1
assert seq_info._args_list["last_page_len"][0] == expected_last_page_len
def test_sequence_info_page_assignments():
"""Test page_assignments property returns correct structure."""
seq_info = SequenceInfo(
max_seq_len=128,
max_batch_size=4,
tokens_per_block=16,
)
# Set up two sequences with different page assignments
input_ids = [[1] * 10, [1] * 20]
seq_info.nest_sequences(
input_ids,
input_pos=0,
cache_loc=[0, 1, 2], # seq 0 has page 0, seq 1 has pages 1 and 2
pages_per_seq=[1, 2],
)
page_assignments = seq_info.page_assignments
assert page_assignments == [[0], [1, 2]]

View File

@ -107,6 +107,7 @@ def test_create_autodeploy_executor_with_guided_decoding(
512 # placeholder to satisfy ADEngine.build_from_config
)
mock_engine.cache_seq_interface.info.vocab_size_padded = vocab_size_padded
mock_engine.cache_seq_interface.max_num_state_slots = max_batch_size
# Mock the specific dependencies requested, plus minimal additional mocks to prevent errors
with (

View File

@ -5,10 +5,11 @@ import torch
import torch.nn as nn
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine
from tensorrt_llm._torch.auto_deploy.shim.demollm import DemoEngine
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
class TransformerLikeModelwithFakeCachePool(nn.Module):
@ -38,12 +39,13 @@ def get_inference_model(cache_seq_interface):
model = TransformerLikeModelwithFakeCachePool(vocab_size, embed_dim, hidden_dim)
model.eval().to(device)
return model
@pytest.mark.parametrize("engine_cls", [ADEngine, DemoEngine])
@pytest.mark.parametrize("attn_page_size", [0, 2, 0])
def test_engine(engine_cls: Type[ADEngine], attn_page_size: int):
@pytest.mark.parametrize("tokens_per_block", [0, 2, 0])
def test_engine(engine_cls: Type[ADEngine], tokens_per_block: int):
"""Test the SimpleEngine functionality."""
seed = 42 # Set random seed for model param init
@ -55,21 +57,24 @@ def test_engine(engine_cls: Type[ADEngine], attn_page_size: int):
max_seq_len = 64
max_batch_size = 8
sequence_info = SequenceInfo(
# Create KvCacheConfig with the specified tokens_per_block (fall back to max_seq_len if 0)
kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block or max_seq_len)
cache_seq_interface = CachedSequenceInterface(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
page_size=attn_page_size,
device=device,
kv_cache_config=kv_cache_config,
)
sequence_info.to(device)
cache_seq_interface.to(device)
engine = engine_cls(get_inference_model, sequence_info, device)
engine = engine_cls(get_inference_model, cache_seq_interface)
# Test basic token generation
with torch.inference_mode():
# Test logits
input_ids = [torch.tensor([0, 1, 2], device=device)]
sequence_info.reset()
sequence_info.nest_sequences(input_ids)
cache_seq_interface.info.reset()
cache_seq_interface.info.nest_sequences(input_ids)
logits = engine._compute_logits()
assert logits is not None, "Logits are None"
@ -77,9 +82,11 @@ def test_engine(engine_cls: Type[ADEngine], attn_page_size: int):
original_logits = get_inference_model(mock_input)(input_ids[0].unsqueeze(0))[0]
assert torch.allclose(logits, original_logits, atol=1e-5), "Generated Token ID mismatch"
cache_seq_interface.shutdown()
@pytest.mark.parametrize("attn_page_size", [0, 2])
def test_demo_engine_sampling(attn_page_size: int):
@pytest.mark.parametrize("tokens_per_block", [0, 2])
def test_demo_engine_sampling(tokens_per_block: int):
"""Test sampling logic specific to DemoEngine."""
seed = 0
torch.manual_seed(seed)
@ -90,19 +97,22 @@ def test_demo_engine_sampling(attn_page_size: int):
max_seq_len = 64
max_batch_size = 8
sequence_info = SequenceInfo(
# Create KvCacheConfig with specified tokens_per_block (use 32 as default if 0)
kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block or 32)
cache_seq_interface = CachedSequenceInterface(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
page_size=attn_page_size,
device=device,
kv_cache_config=kv_cache_config,
)
sequence_info.to(device)
cache_seq_interface.to(device)
engine = DemoEngine(get_inference_model, sequence_info, device)
engine = DemoEngine(get_inference_model, cache_seq_interface)
with torch.inference_mode():
input_ids = [torch.tensor([1, 2, 3, 4], device=device)]
sequence_info.reset()
sequence_info.nest_sequences(input_ids)
cache_seq_interface.info.reset()
cache_seq_interface.info.nest_sequences(input_ids)
logits = engine._compute_logits()
vocab_size = logits.size(-1)
@ -127,6 +137,8 @@ def test_demo_engine_sampling(attn_page_size: int):
torch.testing.assert_close(token_ids_1, token_ids_2)
cache_seq_interface.shutdown()
class _DummyKVCacheManager:
def __init__(self, tokens_per_block: int):
@ -163,8 +175,8 @@ class _DummyRequest:
return self._tokens
@pytest.mark.parametrize("attn_page_size", [256, 2])
def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int):
@pytest.mark.parametrize("tokens_per_block", [256, 2])
def test_ad_engine_chunked_prefill_equivalence(tokens_per_block: int):
"""Verify ADEngine logits match between chunked and non-chunked prefill.
We simulate chunking by splitting a single context request into two chunks and
@ -179,19 +191,22 @@ def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int):
max_seq_len = 64
max_batch_size = 8
sequence_info = SequenceInfo(
# Create KvCacheConfig with specified tokens_per_block
kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block)
cache_seq_interface = CachedSequenceInterface(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
page_size=attn_page_size,
device=device,
kv_cache_config=kv_cache_config,
)
sequence_info.to(device)
cache_seq_interface.to(device)
engine = ADEngine(get_inference_model, sequence_info, device)
engine = ADEngine(get_inference_model, cache_seq_interface)
# A simple prompt; model is position-wise so last token dominates the last logit
tokens = [1, 2, 3, 4, 5, 6]
kv_manager = _DummyKVCacheManager(tokens_per_block=attn_page_size)
kv_manager = _DummyKVCacheManager(tokens_per_block=tokens_per_block)
resource_manager = _DummyResourceManager(kv_manager)
# No-chunk: whole prompt in one request
@ -215,3 +230,237 @@ def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int):
logits_chunked_last = engine.forward(scheduled_requests_part2, resource_manager)["logits"][-1]
torch.testing.assert_close(logits_full_last, logits_chunked_last) # , atol=1e-5)
cache_seq_interface.shutdown()
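# Hedged sketch of the chunking used by the equivalence test above: a single context
# request is split into chunks that together cover the full prompt (hypothetical
# helper, not the scheduler's actual chunking logic).
def _split_prompt(tokens, chunk_size):
    return [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)]
assert _split_prompt([1, 2, 3, 4, 5, 6], 4) == [[1, 2, 3, 4], [5, 6]]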
# =============================================================================
# Hybrid Cache Manager Integration Tests
# =============================================================================
class _DummyHybridKVCacheManager:
"""Simulates MambaHybridCacheManager with mamba_cache_index."""
def __init__(self, tokens_per_block: int, num_slots: int = 8):
self.tokens_per_block = tokens_per_block
# mamba_cache_index maps request_id to slot_idx
self.mamba_cache_index = {i: num_slots - 1 - i for i in range(num_slots)}
self.mamba_cache_free_blocks = num_slots
def get_cache_indices(self, request):
return list(range(1024))
def get_num_kv_blocks(self, num_tokens: int) -> int:
if self.tokens_per_block and self.tokens_per_block > 0:
return (num_tokens + self.tokens_per_block - 1) // self.tokens_per_block
return num_tokens
def get_num_free_blocks(self):
return 100
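# Illustrative arithmetic for get_num_kv_blocks above: tokens are ceil-divided into
# blocks, e.g. 25 tokens at tokens_per_block=16 -> (25 + 16 - 1) // 16 == 2 blocks.
assert (25 + 16 - 1) // 16 == 2
assert (32 + 16 - 1) // 16 == 2
assert (33 + 16 - 1) // 16 == 3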
class _DummyRequestWithRequestId:
"""Request with py_request_id for hybrid cache manager testing."""
def __init__(
self,
tokens: List[int],
begin: int,
size: int,
seq_slot: int = 0,
request_id: int = 0,
):
self._tokens = tokens
self.context_current_position = begin
self.context_chunk_size = size
self.seq_slot = seq_slot
self.py_request_id = request_id
self.py_batch_idx = None
self.py_multimodal_data = None
def get_tokens(self, _beam: int) -> List[int]:
return self._tokens
def test_ad_engine_prepare_inputs_with_hybrid_cache_manager():
"""Test ADEngine _prepare_inputs uses mamba_cache_index when available."""
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
device = torch.device("cuda")
max_seq_len = 64
max_batch_size = 8
tokens_per_block = 16
kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block)
cache_seq_interface = CachedSequenceInterface(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
device=device,
kv_cache_config=kv_cache_config,
)
cache_seq_interface.to(device)
engine = ADEngine(get_inference_model, cache_seq_interface)
# Create hybrid KV cache manager with specific mamba_cache_index mapping
hybrid_manager = _DummyHybridKVCacheManager(tokens_per_block=tokens_per_block)
class _HybridResourceManager:
def __init__(self, kv_mgr):
self._kv = kv_mgr
def get_resource_manager(self, _):
return self._kv
resource_manager = _HybridResourceManager(hybrid_manager)
# Create request with specific request_id
request_id = 3
tokens = [1, 2, 3, 4]
req = _DummyRequestWithRequestId(
tokens=tokens,
begin=0,
size=len(tokens),
seq_slot=0,
request_id=request_id,
)
scheduled = ScheduledRequests()
scheduled.context_requests.append(req)
# Call _prepare_inputs
engine._prepare_inputs(scheduled, resource_manager, new_tokens=None)
# Verify slot_idx was taken from mamba_cache_index, not seq_slot
expected_slot_idx = hybrid_manager.mamba_cache_index[request_id]
actual_slot_idx = cache_seq_interface.info._args_list["slot_idx"][0]
assert actual_slot_idx == expected_slot_idx
cache_seq_interface.shutdown()
def test_ad_engine_prepare_inputs_generation_with_hybrid_cache():
"""Test ADEngine _prepare_inputs handles generation requests with hybrid cache."""
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
device = torch.device("cuda")
max_seq_len = 64
max_batch_size = 8
tokens_per_block = 16
kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block)
cache_seq_interface = CachedSequenceInterface(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
device=device,
kv_cache_config=kv_cache_config,
)
cache_seq_interface.to(device)
engine = ADEngine(get_inference_model, cache_seq_interface)
# Create hybrid KV cache manager
hybrid_manager = _DummyHybridKVCacheManager(tokens_per_block=tokens_per_block)
class _HybridResourceManager:
def __init__(self, kv_mgr):
self._kv = kv_mgr
def get_resource_manager(self, _):
return self._kv
resource_manager = _HybridResourceManager(hybrid_manager)
# Create generation request
class _GenRequest:
def __init__(self, request_id: int, seq_slot: int, num_tokens: int):
self.py_request_id = request_id
self.seq_slot = seq_slot
self.py_batch_idx = None
self.is_dummy = False
self.py_draft_tokens = []
# Mock methods for generation request
def get_token(beam, idx):
return 42 # Dummy token
self.get_token = get_token
self.get_num_tokens = lambda beam: num_tokens
self.max_beam_num_tokens = num_tokens
def get_draft_token_length(self):
return 0
# Create a generation request with specific request_id
request_id = 2
gen_req = _GenRequest(request_id=request_id, seq_slot=5, num_tokens=10)
scheduled = ScheduledRequests()
scheduled.generation_requests.append(gen_req)
# Call _prepare_inputs
engine._prepare_inputs(scheduled, resource_manager, new_tokens=None)
# Verify slot_idx was taken from mamba_cache_index
expected_slot_idx = hybrid_manager.mamba_cache_index[request_id]
actual_slot_idx = cache_seq_interface.info._args_list["slot_idx"][0]
assert actual_slot_idx == expected_slot_idx
cache_seq_interface.shutdown()
def test_ad_engine_with_regular_kv_cache_manager():
"""Test ADEngine falls back to seq_slot when mamba_cache_index not available."""
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
device = torch.device("cuda")
max_seq_len = 64
max_batch_size = 8
tokens_per_block = 16
kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block)
cache_seq_interface = CachedSequenceInterface(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
device=device,
kv_cache_config=kv_cache_config,
)
cache_seq_interface.to(device)
engine = ADEngine(get_inference_model, cache_seq_interface)
# Use regular (non-hybrid) KV cache manager without mamba_cache_index
regular_manager = _DummyKVCacheManager(tokens_per_block=tokens_per_block)
resource_manager = _DummyResourceManager(regular_manager)
# Create request with specific seq_slot
expected_seq_slot = 3
tokens = [1, 2, 3, 4]
req = _DummyRequest(
tokens=tokens,
begin=0,
size=len(tokens),
seq_slot=expected_seq_slot,
)
scheduled = ScheduledRequests()
scheduled.context_requests.append(req)
# Call _prepare_inputs
engine._prepare_inputs(scheduled, resource_manager, new_tokens=None)
# Verify slot_idx falls back to seq_slot
actual_slot_idx = cache_seq_interface.info._args_list["slot_idx"][0]
assert actual_slot_idx == expected_seq_slot
cache_seq_interface.shutdown()
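# Hedged sketch of the slot-selection behavior pinned down by the three tests above:
# prefer the hybrid manager's mamba_cache_index[request_id] when it exists, otherwise
# fall back to the request's seq_slot (an assumed restatement, not the actual
# ADEngine._prepare_inputs implementation).
def _resolve_slot_idx(kv_manager, request_id: int, seq_slot: int) -> int:
    mamba_index = getattr(kv_manager, "mamba_cache_index", None)
    if mamba_index is not None and request_id in mamba_index:
        return mamba_index[request_id]
    return seq_slot
assert _resolve_slot_idx(_DummyHybridKVCacheManager(16), request_id=3, seq_slot=0) == 4
assert _resolve_slot_idx(_DummyKVCacheManager(16), request_id=3, seq_slot=3) == 3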

View File

@ -4,7 +4,6 @@ import pydantic
import pytest
from tensorrt_llm._torch.auto_deploy import LLM, DemoLLM, LlmArgs
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
def test_custom_values():
@ -14,7 +13,6 @@ def test_custom_values():
"model_factory": "AutoModelForImageTextToText",
"model_kwargs": {"custom_param": True},
"skip_loading_weights": True,
"attn_page_size": 128,
"max_seq_len": 2048,
"transforms": {
"detect_sharding": {
@ -25,10 +23,6 @@ def test_custom_values():
"stage": "cache_init",
"backend": "flashinfer",
},
"resize_kv_cache": {
"stage": "cache_init",
"free_mem_ratio": 0.9,
},
},
}
@ -39,32 +33,12 @@ def test_custom_values():
"custom_param": True,
}
assert args.skip_loading_weights
assert args.transforms["resize_kv_cache"]["free_mem_ratio"] == 0.9
assert args.transforms["detect_sharding"]["simple_shard_only"]
assert args.attn_page_size == 128
assert args.max_seq_len == 2048
# backend should be overridden if it was 'TRTLLM'
assert args.transforms["insert_cached_attention"]["backend"] == "flashinfer"
def test_free_mem_ratio_validation():
"""Test free_mem_ratio validation."""
def get_transform_config(free_mem_ratio):
return {"resize_kv_cache": {"stage": "cache_init", "free_mem_ratio": free_mem_ratio}}
# Valid values
InferenceOptimizer(None, get_transform_config(0.0))
InferenceOptimizer(None, get_transform_config(1.0))
InferenceOptimizer(None, get_transform_config(0.5))
# Invalid values
with pytest.raises(ValueError):
InferenceOptimizer(None, get_transform_config(-0.1))
with pytest.raises(ValueError):
InferenceOptimizer(None, get_transform_config(1.1))
# ================================
# Config Flow Tests
# ================================
@ -77,7 +51,6 @@ def test_config_params():
"model": "test-model",
"model_factory": "AutoModelForImageTextToText",
"skip_loading_weights": True,
"attn_page_size": 17,
"max_seq_len": 19,
"max_batch_size": 5,
"world_size": 3,
@ -90,10 +63,6 @@ def test_config_params():
"stage": "cache_init",
"backend": "flashinfer",
},
"resize_kv_cache": {
"stage": "cache_init",
"free_mem_ratio": 0.7,
},
},
}
@ -151,16 +120,11 @@ def test_config_flow(
# Common assertions for both APIs
assert instance.args.model_factory == test_config_params["model_factory"]
assert (
instance.args.transforms["resize_kv_cache"]["free_mem_ratio"]
== test_config_params["transforms"]["resize_kv_cache"]["free_mem_ratio"]
)
assert (
instance.args.transforms["detect_sharding"]["simple_shard_only"]
== test_config_params["transforms"]["detect_sharding"]["simple_shard_only"]
)
assert instance.args.skip_loading_weights == test_config_params["skip_loading_weights"]
assert instance.args.attn_page_size == test_config_params["attn_page_size"]
assert instance.args.max_seq_len == test_config_params["max_seq_len"]
assert instance.args.max_batch_size == test_config_params["max_batch_size"]
@ -215,23 +179,6 @@ def test_parallel_config_validation(parallel_field, invalid_value):
LlmArgs(**kwargs)
@pytest.mark.parametrize(
"backend,expected_attn_page_size",
[
("flashinfer", 64), # Default attn_page_size
("triton", 1024), # Should equal max_seq_len
],
)
def test_attention_backend_page_size_logic(backend, expected_attn_page_size):
"""Test attn_page_size logic for different attention backends."""
args = LlmArgs(
model="test-model",
max_seq_len=1024,
transforms={"insert_cached_attention": {"stage": "cache_init", "backend": backend}},
)
assert args.attn_page_size == expected_attn_page_size
# ================================
# CUDA Graph Batch Sizes Tests
# ================================
@ -351,7 +298,7 @@ class TestSequenceInfoExampleBatchSize:
max_batch_size=1,
max_seq_len=128,
max_num_tokens=128,
page_size=64,
tokens_per_block=64,
)
# Set example sequence (this is what's used during export)
@ -371,7 +318,7 @@ class TestSequenceInfoExampleBatchSize:
max_batch_size=32,
max_seq_len=128,
max_num_tokens=128,
page_size=64,
tokens_per_block=64,
)
seq_info.set_example_sequence()

View File

@ -41,8 +41,10 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
(
"meta-llama/Meta-Llama-3.1-8B-Instruct",
{
"kv_cache_config": {
"free_gpu_memory_fraction": 0.0001,
},
"transforms": {
"resize_kv_cache": {"free_mem_ratio": 0.0001},
"insert_cached_attention": {"backend": "flashinfer"},
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9878
# "compile_model": {"backend": "torch-opt"},

View File

@ -84,25 +84,27 @@ def prepare_dataset(root_dir: str, temp_dir: str, model_path_or_name: str):
@pytest.mark.parametrize("compile_backend", ["torch-compile", "torch-opt", "torch-cudagraph"])
@pytest.mark.parametrize("model_name", ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"])
def test_trtllm_bench(llm_root, compile_backend, model_name): # noqa: F811
config = get_small_model_config(model_name)
args = get_small_model_config(model_name)["args"]
# remove kv_cache_config and max_batch_size to avoid conflicts with trtllm-bench
args.pop("kv_cache_config", None)
args.pop("max_batch_size", None)
with tempfile.TemporaryDirectory() as temp_dir:
extra_llm_api_options_path = f"{temp_dir}/extra_llm_api_options.yaml"
with open(extra_llm_api_options_path, "w") as f:
yaml.dump(
{
**config["args"],
**args,
"transforms": {
"resize_kv_cache": {"enabled": False}, # rely on default estimation
"compile_model": {
"stage": "compile",
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
"backend": compile_backend,
}
},
},
},
f,
)
dataset_path = prepare_dataset(llm_root, temp_dir, config["args"]["model"])
run_benchmark(
model_name, str(config["args"]["model"]), dataset_path, extra_llm_api_options_path
)
dataset_path = prepare_dataset(llm_root, temp_dir, args["model"])
run_benchmark(model_name, str(args["model"]), dataset_path, extra_llm_api_options_path)

View File

@ -22,7 +22,6 @@ from torch.export import Dim
from torch.fx import GraphModule
# Import to register custom op
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
@ -154,13 +153,12 @@ class TestGatherLogitsBeforeLmHeadTransform:
def _create_cached_sequence_interface(self, max_batch_size: int = 8, device: str = "cuda"):
"""Create a mock CachedSequenceInterface for testing."""
seq_info = SequenceInfo(
return CachedSequenceInterface(
max_seq_len=64,
max_batch_size=max_batch_size,
device=device,
max_num_tokens=1024,
)
seq_info.to(device)
return CachedSequenceInterface(seq_info, device=device)
def _check_gather_op_in_graph(self, gm: GraphModule) -> bool:
"""Check if gather_logits_before_lm_head op is in the graph."""

View File

@ -1,4 +1,5 @@
from typing import List, Optional
from unittest.mock import MagicMock
import pytest
import torch
@ -6,7 +7,8 @@ import torch.nn as nn
from _model_test_utils import GQA
from _torch_test_utils import all_close
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig, SequenceInfo
# Initialize resources first
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import PagedResourceHandler
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.models.factory import (
FullModelExportInfo,
@ -14,15 +16,18 @@ from tensorrt_llm._torch.auto_deploy.models.factory import (
SubModuleExportInfo,
)
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm._torch.auto_deploy.transform.interface import Stages, TransformConfig
from tensorrt_llm._torch.auto_deploy.transform.library.kvcache import InitializeCache, ResizeKVCache
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
class DummyFactory(ModelFactory):
"""Dummy factory to pass cache_config for testing."""
"""Dummy factory to pass cache_config_updates for testing."""
def __init__(self, model, cache_config):
def __init__(self, model, cache_config_updates):
self._model = model
self.cache_config = cache_config
self.cache_config_updates = cache_config_updates
def build_model(self, device: str):
return self._model.to(device=device)
@ -33,8 +38,8 @@ class DummyFactory(ModelFactory):
def _load_checkpoint(self, model, device):
return
def get_cache_config(self):
return self.cache_config
def get_cache_config_updates(self):
return self.cache_config_updates
def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]:
return [FullModelExportInfo()]
@ -151,12 +156,19 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
max_position_embeddings = 128
vocab_size = 1000
# set up sequence+cache objects using standard SequenceInfo
ci = SequenceInfo(
# set up sequence+cache objects using CachedSequenceInterface
# Use tokens_per_block=max_position_embeddings so each sequence fits in 1 page for the test
kv_cache_config = KvCacheConfig(
tokens_per_block=max_position_embeddings,
max_tokens=batch_size * max_position_embeddings,
free_gpu_memory_fraction=0.0, # Disable dynamic resizing for test
)
cm = CachedSequenceInterface(
max_seq_len=max_position_embeddings,
max_batch_size=batch_size,
device="cuda",
kv_cache_config=kv_cache_config,
)
cm = CachedSequenceInterface(sequence_info=ci, device="cuda")
# Create the model with embedding layer and SDPA, wrap it in a fake factory
model = GQAWithSdpaAndEmbedding(
@ -175,7 +187,7 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
# Apply the transformation
optimizer = InferenceOptimizer(
DummyFactory(model, CacheConfig()),
DummyFactory(model, cache_config_updates={}),
{
"build_model": {
"stage": "factory",
@ -204,7 +216,7 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
gm = optimizer(cm)
gm.to("cuda")
num_caches = cm.initialize_caches()
num_caches = cm.initialize_resources()
print(f"num_caches: {num_caches}")
# Helper function to call the model with proper sequence nesting
@ -248,3 +260,273 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
# Test 4: Exportability of the transformed model
exported_gm = torch_export_to_gm(gm, args=(), kwargs=cm.named_args)
assert exported_gm is not None
# =============================================================================
# Transform Unit Tests for Refactored Pipeline
# =============================================================================
@pytest.fixture
def dummy_cached_interface():
"""Create a CachedSequenceInterface for transform testing."""
kv_cache_config = KvCacheConfig(
tokens_per_block=32,
max_tokens=256,
free_gpu_memory_fraction=0.0,
)
return CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=kv_cache_config,
)
def test_initialize_cache_transform_calls_initialize_resources(dummy_cached_interface):
"""Verify InitializeCache transform calls cm.initialize_resources()."""
# Create a mock module
mock_module = MagicMock()
# Create the transform with a proper config
transform = InitializeCache(config=TransformConfig(stage=Stages.PATTERN_MATCHER))
# Add a resource to verify initialize_resources is called
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import PagedResourceHandler
dummy_cached_interface.add_resource(
"k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)
)
# Mock the factory and shared_config
mock_factory = MagicMock()
mock_shared_config = MagicMock()
# Run the transform
result_mod, info = transform._apply_to_full_model(
mock_module, dummy_cached_interface, mock_factory, mock_shared_config
)
# Verify caches were initialized
assert info.skipped is False
assert info.num_matches >= 1
assert dummy_cached_interface.kv_cache_manager is not None
def test_resize_kv_cache_transform_skipped_when_not_needed(dummy_cached_interface):
"""Verify ResizeKVCache transform is skipped when resize not needed."""
dummy_cached_interface.add_resource(
"k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)
)
dummy_cached_interface.initialize_resources()
# Create the transform with a proper config
transform = ResizeKVCache(config=TransformConfig(stage=Stages.PATTERN_MATCHER))
# Create a mock module
mock_module = MagicMock()
# Mock forward call
mock_module.side_effect = lambda **kwargs: None
mock_factory = MagicMock()
mock_shared_config = MagicMock()
# Run the transform - should be skipped since free_gpu_memory_fraction=0.0
result_mod, info = transform._apply_to_full_model(
mock_module, dummy_cached_interface, mock_factory, mock_shared_config
)
# Verify transform was skipped
assert info.skipped is True
def test_resize_kv_cache_transform_runs_when_needed():
"""Verify ResizeKVCache transform runs when resize is needed."""
# Create interface with resizing enabled
kv_cache_config = KvCacheConfig(
tokens_per_block=32,
max_tokens=256,
free_gpu_memory_fraction=0.5, # Enable resizing
)
cm = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=kv_cache_config,
)
cm.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
cm.initialize_resources()
# Create the transform with a proper config
transform = ResizeKVCache(config=TransformConfig(stage=Stages.PATTERN_MATCHER))
# Create a simple mock module that just returns None
class MockModule:
def __call__(self, **kwargs):
return None
mock_module = MockModule()
mock_factory = MagicMock()
mock_shared_config = MagicMock()
# Run the transform
result_mod, info = transform._apply_to_full_model(
mock_module, cm, mock_factory, mock_shared_config
)
# Verify transform was not skipped
assert info.skipped is False
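# Hedged sketch of the skip condition the two ResizeKVCache tests above rely on
# (an assumed predicate, not the transform's actual code): resizing is attempted only
# when a non-zero fraction of free GPU memory is reserved for the KV cache.
def _resize_needed(free_gpu_memory_fraction: float) -> bool:
    return free_gpu_memory_fraction > 0.0
assert _resize_needed(0.0) is False
assert _resize_needed(0.5) is True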
def test_insert_cached_attention_uses_add_resource():
"""Verify InsertCachedAttention uses cm.add_resource() for cache registration."""
# This test verifies the integration point between InsertCachedAttention
# and CachedSequenceInterface.add_resource() by checking that after the
# transform, resources are registered in the interface.
num_attention_heads = 8
hidden_size = 512
num_key_value_heads = 8
vocab_size = 1000
batch_size = 4
max_seq_len = 64
kv_cache_config = KvCacheConfig(
tokens_per_block=max_seq_len,
max_tokens=batch_size * max_seq_len,
free_gpu_memory_fraction=0.0,
)
cm = CachedSequenceInterface(
max_seq_len=max_seq_len,
max_batch_size=batch_size,
device="cuda",
kv_cache_config=kv_cache_config,
)
# Create a model
model = GQAWithSdpaAndEmbedding(
num_attention_heads,
hidden_size,
num_key_value_heads,
vocab_size=vocab_size,
).to(dtype=torch.float16, device="cuda")
# Apply transformation
optimizer = InferenceOptimizer(
DummyFactory(model, cache_config_updates={}),
{
"build_model": {
"stage": "factory",
"run_per_gm": False,
"device": "cuda",
"run_graph_cleanup": False,
"requires_clean_graph": False,
},
"export_to_gm": {
"stage": "export",
"strict": False,
"run_per_gm": False,
"clone_state_dict": True,
"run_graph_cleanup": False,
"requires_clean_graph": False,
},
"cleanup_input_constraints": {
"stage": "post_export",
},
"insert_cached_attention": {
"stage": "cache_init",
"backend": "triton",
},
},
)
optimizer(cm)
# Verify resources were added
assert len(cm._resource_lookup) > 0
# Should have k_cache and v_cache resources registered
resource_names = list(cm._resource_lookup.keys())
assert any("k_cache" in name for name in resource_names)
assert any("v_cache" in name for name in resource_names)
def test_insert_cached_attention_passes_kv_cache_config():
"""Verify InsertCachedAttention passes cm.kv_cache_config to get_cache_initializers."""
# This test verifies that the KvCacheConfig from the interface is used
# when initializing cache resources (e.g., for dtype configuration).
num_attention_heads = 8
hidden_size = 512
num_key_value_heads = 8
vocab_size = 1000
batch_size = 4
max_seq_len = 64
# Use specific dtype in kv_cache_config
kv_cache_config = KvCacheConfig(
tokens_per_block=max_seq_len,
max_tokens=batch_size * max_seq_len,
free_gpu_memory_fraction=0.0,
dtype="bfloat16", # Specify explicit dtype
)
cm = CachedSequenceInterface(
max_seq_len=max_seq_len,
max_batch_size=batch_size,
device="cuda",
kv_cache_config=kv_cache_config,
)
# Verify kv_cache_config is accessible
assert cm.kv_cache_config.dtype == "bfloat16"
# Create a model
model = GQAWithSdpaAndEmbedding(
num_attention_heads,
hidden_size,
num_key_value_heads,
vocab_size=vocab_size,
).to(dtype=torch.bfloat16, device="cuda")
# Apply transformation
optimizer = InferenceOptimizer(
DummyFactory(model, cache_config_updates={}),
{
"build_model": {
"stage": "factory",
"run_per_gm": False,
"device": "cuda",
"run_graph_cleanup": False,
"requires_clean_graph": False,
},
"export_to_gm": {
"stage": "export",
"strict": False,
"run_per_gm": False,
"clone_state_dict": True,
"run_graph_cleanup": False,
"requires_clean_graph": False,
},
"cleanup_input_constraints": {
"stage": "post_export",
},
"insert_cached_attention": {
"stage": "cache_init",
"backend": "triton",
},
},
)
optimizer(cm)
# Initialize resources
cm.initialize_resources()
assert not cm.is_paged(), "triton should not use paged resources"
assert cm._caches, "at least some resources should be present"
# Verify cache dtype matches config
for name, handler in cm._resource_lookup.items():
if hasattr(handler, "dtype"):
assert handler.dtype == torch.bfloat16