Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 18:21:52 +08:00)

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>

This commit is contained in:
  parent 7f8c260601
  commit ff3a494f5c
@@ -55,14 +55,17 @@ skip_loading_weights: false
# Sequence configuration
max_batch_size: 256

# transform options
# KV cache configuration
kv_cache_config:
  # fraction of free memory to use for kv-caches
  free_gpu_memory_fraction: 0.8

# transform options
transforms:
  insert_cached_attention:
    # attention backend
    backend: flashinfer
  resize_kv_cache:
    # fraction of free memory to use for kv-caches
    free_mem_ratio: 0.8
  compile_model:
    # compilation backend
    backend: torch-opt
@@ -80,7 +83,7 @@ Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GP
|-----------|---------|-------------|
| `compile_backend` | `torch-compile` | Compilation backend: `torch-simple`, `torch-compile`, `torch-cudagraph`, `torch-opt` |
| `runtime` | `trtllm` | Runtime engine: `trtllm`, `demollm` |
| `free_mem_ratio` | `0.0` | Fraction of available GPU memory for KV cache (0.0-1.0) |
| `kv_cache_config.free_gpu_memory_fraction` | `0.9` | Fraction of available GPU memory for KV cache (0.0-1.0) |
| `skip_loading_weights` | `false` | Skip weight loading for architecture-only benchmarks |

### CUDA Graph Optimization
@@ -95,7 +98,7 @@ For optimal CUDA graph performance, specify batch sizes that match your expected

## Performance Optimization Tips

1. **Memory Management**: Set `free_mem_ratio` to 0.8-0.9 for optimal KV cache utilization
1. **Memory Management**: Set `kv_cache_config.free_gpu_memory_fraction` to 0.8-0.9 for optimal KV cache utilization
1. **Compilation Backend**: Use `torch-opt` for production workloads
1. **Attention Backend**: `flashinfer` generally provides the best performance for most models
1. **CUDA Graphs**: Enable CUDA graphs for batch sizes that match your production traffic patterns.
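For illustration, the tips above can be collected into a single set of options for the AutoDeploy `LLM` constructor; the exact keyword names are a sketch inferred from the table and examples in this commit, not a verified API reference.

```python
# Illustrative only: option names mirror the table and tips above.
llm_kwargs = {
    "compile_backend": "torch-opt",          # production-oriented compilation backend
    "attn_backend": "flashinfer",            # generally the fastest attention backend
    "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32],          # match expected traffic patterns
    "kv_cache_config": {"free_gpu_memory_fraction": 0.85},   # 80-90% of free GPU memory for KV cache
    "skip_loading_weights": False,           # set True only for architecture-only benchmarks
}
```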
@@ -15,10 +15,8 @@ llm = LLM(
    compile_backend="torch-compile",
    model_kwargs={"num_hidden_layers": 2},  # test with smaller model configuration
    attn_backend="flashinfer",  # choose between "triton" and "flashinfer"
    attn_page_size=64,  # page size for attention (tokens_per_block, should be == max_seq_len for triton)
    skip_loading_weights=False,
    model_factory="AutoModelForCausalLM",  # choose appropriate model factory
    free_mem_ratio=0.8,  # fraction of available memory for cache
    max_seq_len=<MAX_SEQ_LEN>,
    max_batch_size=<MAX_BATCH_SIZE>,
)
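Since the commit replaces the flat `free_mem_ratio` argument with the LLM-API `KvCacheConfig`, the corresponding setting can be constructed as sketched below; passing it to `LLM(...)` via a `kv_cache_config` keyword is an assumption based on the updated docs in this commit, not a confirmed signature.

```python
from tensorrt_llm.llmapi.llm_args import KvCacheConfig  # same import used elsewhere in this commit

# New-style KV cache sizing: use 80% of the GPU memory that is free after loading weights.
kv_cfg = KvCacheConfig(free_gpu_memory_fraction=0.8)
print(kv_cfg.free_gpu_memory_fraction)  # 0.8
```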
@@ -49,14 +49,17 @@ skip_loading_weights: false
# Sequence configuration
max_batch_size: 256

# transform options
# KV cache configuration
kv_cache_config:
  # fraction of free memory to use for kv-caches
  free_gpu_memory_fraction: 0.9

# transform options
transforms:
  insert_cached_attention:
    # attention backend
    backend: flashinfer
  resize_kv_cache:
    # fraction of free memory to use for kv-caches
    free_mem_ratio: 0.8
  compile_model:
    # compilation backend
    backend: torch-opt
@@ -74,7 +77,7 @@ Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GP
|-----------|---------|-------------|
| `compile_backend` | `torch-compile` | Compilation backend: `torch-simple`, `torch-compile`, `torch-cudagraph`, `torch-opt` |
| `runtime` | `trtllm` | Runtime engine: `trtllm`, `demollm` |
| `free_mem_ratio` | `0.0` | Fraction of available GPU memory for KV cache (0.0-1.0) |
| `kv_cache_config.free_gpu_memory_fraction` | `0.9` | Fraction of available GPU memory for KV cache (0.0-1.0) |
| `skip_loading_weights` | `false` | Skip weight loading for architecture-only benchmarks |

### CUDA Graph Optimization
@@ -89,7 +92,7 @@ For optimal CUDA graph performance, specify batch sizes that match your expected

## Performance Optimization Tips

1. **Memory Management**: Set `free_mem_ratio` to 0.8-0.9 for optimal KV cache utilization
1. **Memory Management**: Set `kv_cache_config.free_gpu_memory_fraction` to 0.8-0.9 for optimal KV cache utilization
1. **Compilation Backend**: Use `torch-opt` for production workloads
1. **Attention Backend**: `flashinfer` generally provides the best performance for most models
1. **CUDA Graphs**: Enable CUDA graphs for batch sizes that match your production traffic patterns.

@@ -54,14 +54,17 @@ max_batch_size: 256
# multi-gpu execution
world_size: 1

# transform options
# KV cache configuration
kv_cache_config:
  # fraction of free memory to use for kv-caches
  free_gpu_memory_fraction: 0.9

# transform options
transforms:
  insert_cached_attention:
    # attention backend
    backend: flashinfer
  resize_kv_cache:
    # fraction of free memory to use for kv-caches
    free_mem_ratio: 0.8
  compile_model:
    # compilation backend
    backend: torch-opt
@@ -77,7 +80,7 @@ transforms:
- Prefer `compile_backend: torch-opt`
- Use `attn_backend: flashinfer`
- Set realistic `cuda_graph_batch_sizes` that match expected traffic
- Tune `free_mem_ratio` to 0.8–0.9
- Tune `kv_cache_config.free_gpu_memory_fraction` to 0.8–0.9

## See also
@@ -17,12 +17,9 @@ llm = LLM(
    transforms={
        "insert_cached_attention": {"backend": "flashinfer"},  # or "triton"
        "insert_cached_mla_attention": {"backend": "MultiHeadLatentAttention"},
        "resize_kv_cache": {"free_mem_ratio": 0.8},
        "compile_model": {"backend": "torch-compile"},
        "detect_sharding": {"simple_shard_only": False},

    },
    attn_page_size=64,  # page size for attention
    skip_loading_weights=False,
    max_seq_len=<MAX_SEQ_LEN>,
    max_batch_size=<MAX_BATCH_SIZE>,

@@ -49,7 +49,6 @@ Below is a non-exhaustive list of common configuration options:
| `--args.mla-backend` | Specifies implementation for multi-head latent attention |
| `--args.max-seq-len` | Maximum sequence length for inference/cache |
| `--args.max-batch-size` | Maximum dimension for statically allocated KV cache |
| `--args.attn-page-size` | Page size for attention |
| `--prompt.batch-size` | Number of queries to generate |
| `--benchmark.enabled` | Whether to run the built-in benchmark (true/false) |

@@ -125,10 +124,8 @@ llm = LLM(
    compile_backend="torch-compile",
    model_kwargs={"num_hidden_layers": 2},  # test with smaller model configuration
    attn_backend="flashinfer",  # choose between "triton" and "flashinfer"
    attn_page_size=64,  # page size for attention (tokens_per_block, should be == max_seq_len for triton)
    skip_loading_weights=False,
    model_factory="AutoModelForCausalLM",  # choose appropriate model factory
    free_mem_ratio=0.8,  # fraction of available memory for cache
    max_seq_len=<MAX_SEQ_LEN>,
    max_batch_size=<MAX_BATCH_SIZE>,
)

@@ -3,7 +3,6 @@

max_batch_size: 1024
max_num_tokens: 2048
free_mem_ratio: 0.9
trust_remote_code: true
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024]
kv_cache_config:

@@ -3,7 +3,6 @@

max_batch_size: 1024
max_num_tokens: 2048
free_mem_ratio: 0.9
trust_remote_code: true
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024]
kv_cache_config:
@@ -6,12 +6,12 @@ enable_chunked_prefill: true
attn_backend: flashinfer
model_factory: AutoModelForCausalLM
skip_loading_weights: false
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9884
free_mem_ratio: 0.88
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384]
kv_cache_config:
  # disable kv_cache reuse since not supported for hybrid/ssm models
  enable_block_reuse: false
  free_gpu_memory_fraction: 0.88
  # tunable mamba cache dtype
  # --> use float32 for accuracy and default (auto) for speed
  mamba_ssm_cache_dtype: auto
transforms:
  detect_sharding:
    allreduce_strategy: SYMM_MEM
@@ -39,12 +39,6 @@ transforms:
  multi_stream_moe:
    stage: compile
    enabled: true
  # tunable mamba cache dtype
  # --> use float32 for accuracy and default (null) for speed
  insert_cached_ssm_attention:
    cache_config:
      # mamba_dtype: float32
      mamba_dtype: null
  gather_logits_before_lm_head:
    # TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default
    enabled: true
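The comments above describe the accuracy/speed trade-off for the Mamba SSM cache dtype; expressed as Python overlays (keys taken from the YAML above, values purely illustrative):

```python
# "auto" keeps the model dtype (faster); "float32" trades memory and speed for accuracy.
kv_cache_config_fast = {"enable_block_reuse": False, "mamba_ssm_cache_dtype": "auto"}
kv_cache_config_accurate = {"enable_block_reuse": False, "mamba_ssm_cache_dtype": "float32"}
```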
@@ -4,11 +4,7 @@ max_seq_len: 2097152
max_num_tokens: 8192
enable_chunked_prefill: true
model_factory: NemotronFlashForCausalLM
free_mem_ratio: 0.9
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 96, 128, 256, 320, 384]
kv_cache_config:
  # disable kv_cache reuse since not supported for hybrid/ssm models
  enable_block_reuse: false
transforms:
  gather_logits_before_lm_head:
    # TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default

@@ -6,11 +6,11 @@ enable_chunked_prefill: true
attn_backend: flashinfer
model_factory: AutoModelForCausalLM
skip_loading_weights: false
free_mem_ratio: 0.9
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384]
kv_cache_config:
  # disable kv_cache reuse since not supported for hybrid/ssm models
  enable_block_reuse: false
  # tunable mamba cache dtype
  # --> use float32 for accuracy and default (auto) for speed
  mamba_ssm_cache_dtype: auto
transforms:
  detect_sharding:
    allreduce_strategy: SYMM_MEM
@@ -38,12 +38,6 @@ transforms:
  multi_stream_moe:
    stage: compile
    enabled: false
  # tunable mamba cache dtype
  # --> use float32 for accuracy and default (null) for speed
  insert_cached_ssm_attention:
    cache_config:
      # mamba_dtype: float32
      mamba_dtype: null
  gather_logits_before_lm_head:
    # TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default
    enabled: true

@@ -37,15 +37,16 @@ enable_chunked_prefill: true
attn_backend: flashinfer
model_factory: AutoModelForCausalLM
skip_loading_weights: false
free_mem_ratio: 0.9
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384]
kv_cache_config:
  # disable kv_cache reuse since not supported for hybrid/ssm models
  enable_block_reuse: false
  free_gpu_memory_fraction: 0.88
  # tunable mamba cache dtype
  # --> use float32 for accuracy and default (auto) for speed
  # mamba_ssm_cache_dtype: float32
transforms:
  detect_sharding:
    sharding_dims: ['ep', 'bmm']
    allreduce_strategy: 'SYMM_MEM'
    sharding_dims: ['ep', 'bmm']
    manual_config:
      head_dim: 128
      tp_plan:
@@ -69,9 +70,8 @@ transforms:
  multi_stream_moe:
    stage: compile
    enabled: true
  insert_cached_ssm_attention:
    cache_config:
      mamba_dtype: float32
  gather_logits_before_lm_head:
    enabled: true
  fuse_mamba_a_log:
    stage: post_load_fusion
    enabled: true
@@ -99,9 +99,11 @@ transforms:
    stage: weight_load
    run_per_gm: false
    checkpoint_device: null
    expect_mem_change: true
  move_inputs_to_device:
    stage: weight_load
    run_per_gm: false
    expect_mem_change: true
  ############################################################################################
  # RUN POST-LOAD FUSION AND OPTIMIZATIONS
  ############################################################################################
@@ -122,18 +124,17 @@ transforms:
    backend: trtllm
  fuse_moe:
    stage: post_load_fusion
    enabled: true
    expect_mem_change: true
    backend: trtllm
  fuse_fp8_moe:
    stage: post_load_fusion
    enabled: true
    expect_mem_change: true
    backend: trtllm
  fuse_nvfp4_moe:
    stage: post_load_fusion
    enabled: true
    expect_mem_change: true
  fuse_allreduce_residual_rmsnorm:
    stage: post_load_fusion
  # TODO (lucaslie): add backend selection as part of configurable inference optimizers
  fuse_rmsnorm:
    stage: post_load_fusion
    rmsnorm_backend: flashinfer
@@ -175,11 +176,12 @@ transforms:
    backend: cached_residual_add
  initialize_cache:
    stage: cache_init
    expect_mem_change: true
    run_per_gm: false
  resize_kv_cache:
    stage: cache_init
    expect_mem_change: true
    run_per_gm: false
    free_mem_ratio: 0.0
  ############################################################################################
  # COMPILE MODEL
  ############################################################################################
@@ -190,6 +192,7 @@ transforms:
    enabled: false
  compile_model:
    stage: compile
    expect_mem_change: true
    run_per_gm: false
    cuda_graph_batch_sizes: null
    backend: torch-compile
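Every entry in this transform pipeline uses the same handful of knobs; a minimal sketch of one entry as a Python dict (field names taken from the hunks above, values illustrative):

```python
transform_entry = {
    "stage": "post_load_fusion",   # pipeline stage in which the transform runs
    "enabled": True,               # toggles the transform on or off
    "run_per_gm": False,           # whether to run once per graph module
    "expect_mem_change": True,     # transform is allowed to change GPU memory usage
    "backend": "trtllm",           # backend selection, where applicable
}
```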
@@ -7,12 +7,14 @@ transforms:
  build_and_load_factory_model:
    stage: factory
    run_per_gm: false
    expect_mem_change: true
  ############################################################################################
  # MOVE ARGUMENTS TO DEVICE
  ############################################################################################
  move_inputs_to_device:
    stage: weight_load
    run_per_gm: false
    expect_mem_change: true
  ############################################################################################
  # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
  ############################################################################################
@@ -26,10 +28,11 @@ transforms:
  initialize_cache:
    stage: cache_init
    run_per_gm: false
    expect_mem_change: true
  resize_kv_cache:
    stage: cache_init
    run_per_gm: false
    free_mem_ratio: 0.0
    expect_mem_change: true
  ############################################################################################
  # COMPILE MODEL
  ############################################################################################

@@ -9,16 +9,18 @@ and operates on a purely functional paradigm that is compatible with the torch c

"""

import math
from abc import ABC, abstractmethod
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union, final

import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator
from torch._ops import OpOverloadPacket
from torch.fx import Node
from torch.types import Number

from ...._utils import nvtx_range
from tensorrt_llm.llmapi.llm_args import KvCacheConfig

from ...._utils import nvtx_range, str_dtype_to_torch
from ..utils.logger import ad_logger

Constant = Union[int, float, str, None]
@@ -281,44 +283,6 @@ class InputBuffer:
        self._device_views = self._create_views(self._device_buffer)


class CacheConfig(BaseModel):
    """Cache configuration for attention-related dtypes."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.")
    mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.")
    delta_dtype: Optional[torch.dtype] = Field(
        default=torch.float32, description="Delta cache dtype. Defaults to float32."
    )

    @field_validator("dtype", "mamba_dtype", "delta_dtype", mode="before")
    @classmethod
    def _coerce_dtype(cls, value):
        if value is None or isinstance(value, torch.dtype):
            return value
        if isinstance(value, str):
            dtype = getattr(torch, value, None)
            assert isinstance(dtype, torch.dtype), f"Invalid {dtype=}"
            return dtype
        return value

    def __or__(self, other: "CacheConfig") -> "CacheConfig":
        """Combine two CacheConfig objects field-wise using Python's `or` semantics.

        For each field, selects the first non-None value between `self` and `other`.
        """
        if not isinstance(other, CacheConfig):
            raise NotImplementedError(f"Cannot combine CacheConfig with {type(other)}")
        merged_kwargs = {}
        for field_name in type(self).model_fields.keys():
            merged_kwargs[field_name] = getattr(self, field_name) or getattr(other, field_name)
        return CacheConfig(**merged_kwargs)

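For reference, the removed `CacheConfig.__or__` merged configs field-wise with a first-non-None-wins rule; a small illustrative example of that behavior (values invented):

```python
import torch

a = CacheConfig(dtype=torch.float16)                              # mamba_dtype left as None
b = CacheConfig(dtype=torch.bfloat16, mamba_dtype=torch.float32)
merged = a | b
assert merged.dtype is torch.float16        # taken from `a`
assert merged.mamba_dtype is torch.float32  # falls back to `b`
```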
class SequenceInfo:
    """An interface to hold information about how the sequence is laid out and stored in cache.

@@ -398,9 +362,9 @@ class SequenceInfo:

    def __init__(
        self,
        max_seq_len: int = 1,
        max_batch_size: int = 1,
        page_size: int = 0,
        max_seq_len: int,
        max_batch_size: int,
        tokens_per_block: Optional[int] = None,
        max_num_tokens: Optional[int] = None,
        vocab_size_padded: Optional[int] = None,
    ):
@@ -411,9 +375,7 @@ class SequenceInfo:
                includes the tokens in the input sequence and the tokens generated by the model.
            max_batch_size: corresponds to the maximum number of sequences (or requests) that the
                model can process.
            page_size: corresponds to the page size of the cache. For an unpaged cache, the page
                size should be set to max_seq_len. Also note that two sequences in a batch can not
                share a page.
            tokens_per_block: corresponds to the tokens per block of the cache.
            max_num_tokens: corresponds to the maximum number of tokens that the model can process
                across all sequences in the batch. If a batch is composed of context-only requests
                of input sequence length ISL, then the maximum number of sequences possible in the
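A minimal sketch of constructing `SequenceInfo` with the new signature (values are arbitrary; per the code below, omitting `tokens_per_block` falls back to `max_seq_len`, i.e. the old unpaged behavior):

```python
seq_info = SequenceInfo(
    max_seq_len=4096,
    max_batch_size=16,
    tokens_per_block=64,
    max_num_tokens=8192,
)
```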
@@ -427,46 +389,27 @@ class SequenceInfo:
        # set up basic attributes
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size
        self.page_size = page_size if page_size > 0 else max_seq_len
        self.vocab_size_padded = vocab_size_padded
        # NOTE (lucaslie): WAR to address issue when using flashinfer attention with
        self.tokens_per_block = tokens_per_block or max_seq_len
        # NOTE (lucaslie): +1 is a WAR to address issue when using flashinfer attention with
        # (max_batch_size, max_seq_len) input in trtllm runtime.
        # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
        max_seq_len_adjusted = self.max_seq_len + 1
        self.max_num_tokens = max_num_tokens or (max_seq_len + 1) * max_batch_size

        # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9883 clean up this hack
        self.max_state_slots = max_batch_size + 1
        # TODO (lucaslie): can we remove this eventually from this i/f?
        self.vocab_size_padded = vocab_size_padded

        # if the provided max_num_tokens is less than the max_batch_size * max_seq_len_adjusted,
        # we use the provided max_num_tokens. If max_num_tokens provided is more, we still use
        # max_batch_size * max_seq_len_adjusted since the extra tokens cannot be used.
        self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted
        if max_num_tokens is not None and max_num_tokens > 0:
            self.max_num_tokens = min(self.max_num_tokens, max_num_tokens)

        # Num pages can not be less than max_batch_size.
        self._num_pages = max(
            self.max_batch_size,
            (self.max_num_tokens) // self.page_size  # floored number of pages
            + (self.max_num_tokens / self.max_batch_size % self.page_size > 0)  # check for overflow
            * self.max_batch_size,  # +1 page per sequence if overflow is required
        )
        # sanity check
        assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size"

        # cache_loc requires some special treatment due to block reuse. Note that the constraint for
        # cache_loc with block_reuse is as follows:
        # 0 <= cache_loc < num_pages
        # len(cache_loc) <= max_num_cache_loc_assignments
        max_num_cache_loc_assignments = (
            max_seq_len_adjusted // self.page_size + 1
        ) * self.max_batch_size
        # NOTE: we keep an extra state slot around to simplify cuda graph padding
        # WHY?
        # Requests that just finished won't free their used resources immediately. Specifically, the
        # running order is self.scheduler.schedule_request, self._forward_step() and
        # self._process_previous_batch() in the PyExecutor. Hence, the current forward step will
        # remove finished requests but will not remove mamba_cache immediately and therefore it
        # won't be available in time for padding in the next forward step.
        self.max_num_state_slots = max_batch_size + 1

        # log parameters
        ad_logger.info(
            f"[SequenceInfo:] {self.max_seq_len=}, {self.max_batch_size=}, {self.page_size=}, "
            f"{self.max_num_tokens=} (inferred), {max_num_tokens=} (provided), {self.num_pages=}, "
            f"{max_num_cache_loc_assignments=}"
            f"[SequenceInfo:] {self.max_seq_len=}, {self.max_batch_size=}, {self.max_num_tokens=}"
        )

        # indicator if extra args are activated that are needed for cached attention backends
@@ -496,8 +439,10 @@ class SequenceInfo:
            # OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
            ("_gather_idx", self.max_num_tokens, torch.int),
            ("_mask_scatter_indices", self.max_num_tokens, torch.int),
            # cache_loc is LAST for truncation optimization (it's the largest tensor)
            ("cache_loc", max_num_cache_loc_assignments, torch.int),
            # cache_loc is LAST for truncation optimization (it can be the largest tensor)
            # NOTE: sufficient for max_num_tokens forward pass. will be resized when KVCacheManager
            # is created.
            ("cache_loc", self.max_num_tokens, torch.int),
        ]

        # Create the InputBuffer that manages contiguous host and device memory
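A small worked example (numbers chosen for illustration) of how the new defaults above fall out:

```python
max_seq_len, max_batch_size = 1024, 8

# Default when max_num_tokens is not provided: one extra token per sequence is reserved
# as the flashinfer/trtllm workaround referenced above.
max_num_tokens = (max_seq_len + 1) * max_batch_size   # 8200
max_num_state_slots = max_batch_size + 1              # 9: extra slot simplifies cudagraph padding

# cache_loc is initially sized for a max_num_tokens forward pass and resized later,
# once the KVCacheManager knows the actual number of blocks.
initial_cache_loc_capacity = max_num_tokens           # 8200
```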
@@ -619,35 +564,44 @@ class SequenceInfo:
    def is_generate(self) -> bool:
        return all(sl == 1 for sl in self.seq_len)

    @property
    def num_pages(self) -> int:
        return self._num_pages

    @num_pages.setter
    def num_pages(self, value):
        self._num_pages = value
        # Check if we need to resize cache_loc (it's the last tensor in the buffer)
        cache_loc_capacity = self._input_buffer.get_capacity("cache_loc")
        if value > cache_loc_capacity:
            ad_logger.info(
                f"Resizing cache_loc capacity from {cache_loc_capacity} to {value} "
                f"to accommodate num_pages={value}"
            )
            # Resize the input buffer (cache_loc is the last tensor, so this is supported)
            self._input_buffer.resize("cache_loc", value)
            # Also resize the args_list to match
            old_size = len(self._args_list["cache_loc"])
            self._args_list["cache_loc"].extend([0] * (value - old_size))

    @property
    def is_paged(self) -> bool:
        return self.page_size < self.max_seq_len

    @property
    def page_assignments(self) -> List[List[int]]:
        """Return the page assignments for each sequence."""
        return self._get_page_assignments(self.cache_loc, self.pages_per_seq)

    def estimate_cache_tokens_per_forward(self) -> int:
        """Estimate the max number of tokens that will be cached for a forward pass.

        It is estimated assuming a worst-case allocation of tokens across sequences in a batch.
        """
        seq_len = math.ceil(self.max_num_tokens / self.max_batch_size)
        num_blocks_estimate_per_seq = math.ceil(seq_len / self.tokens_per_block)
        num_blocks_estimate = num_blocks_estimate_per_seq * self.max_batch_size
        return num_blocks_estimate * self.tokens_per_block

    def estimate_cache_loc_capacity(self, num_blocks: int) -> None:
        """Estimate needed capacity of cache_loc based on available blocks and resize."""
        cache_loc_capacity = self._input_buffer.get_capacity("cache_loc")

        # cache_loc requires some special treatment due to block reuse. Note that the constraint for
        # cache_loc with block_reuse is as follows:
        # 0 <= cache_loc < num_blocks
        # len(cache_loc) <= num_blocks * max_batch_size
        # NOTE: # However, num_blocks * max_batch_size may be a potentially very large number and an
        # overestimation, see we assume at most 10x reuse of blocks.
        estimated_capacity = num_blocks * min(10, self.max_batch_size)

        # NOTE (lucaslie): WAR to address issue when using flashinfer attention with
        # (max_batch_size, max_seq_len) input in trtllm runtime.
        # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
        estimated_capacity = estimated_capacity + 1

        if estimated_capacity > cache_loc_capacity:
            self._input_buffer.resize("cache_loc", estimated_capacity)
            # Also resize the args_list to match
            old_size = len(self._args_list["cache_loc"])
            self._args_list["cache_loc"].extend([0] * (estimated_capacity - old_size))

    @staticmethod
    def _get_page_assignments(
        cache_locations: List[int], pages_per_sequence: List[int]
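Continuing the illustrative numbers, the two estimators above behave roughly as follows (plain arithmetic, no framework calls):

```python
import math

max_num_tokens, max_batch_size, tokens_per_block = 8200, 8, 64

# estimate_cache_tokens_per_forward(): worst-case per-sequence split, rounded up to whole blocks.
seq_len = math.ceil(max_num_tokens / max_batch_size)                 # 1025
blocks_per_seq = math.ceil(seq_len / tokens_per_block)               # 17
cached_tokens = blocks_per_seq * max_batch_size * tokens_per_block   # 8704

# estimate_cache_loc_capacity(num_blocks): capped block-reuse assumption (at most 10x)
# plus the +1 flashinfer workaround.
num_blocks = 500
estimated_capacity = num_blocks * min(10, max_batch_size) + 1        # 4001
```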
@@ -732,7 +686,8 @@ class SequenceInfo:

        # figure out page assignments
        pages_per_seq = [
            len(ids_one_seq) // self.page_size + (len(ids_one_seq) % self.page_size > 0)
            len(ids_one_seq) // self.tokens_per_block
            + (len(ids_one_seq) % self.tokens_per_block > 0)
            for ids_one_seq in input_ids
        ]
        cache_loc = list(range(sum(pages_per_seq)))
@@ -889,7 +844,7 @@ class SequenceInfo:

        This i/f will ensure that all sequence info args are updated accordingly. Reset values are
        chosen as "neutral" values so that for cases like rounding up batch sizes for cudagraph we
        only write to unused buffers/caches.
        only write to unused caches.
        """
        ### UPDATE SEQUENCE LENGTH AND INPUT POSITION FIRST SINCE IT'S USED FOR OTHER UPDATES ######
        if seq_len is None:
@@ -953,7 +908,7 @@ class SequenceInfo:

        # update last page length
        if last_page_len is None:
            last_page_len = [(slwc - 1) % self.page_size + 1 for slwc in seq_len_with_cache]
            last_page_len = [(slwc - 1) % self.tokens_per_block + 1 for slwc in seq_len_with_cache]
        self._store_arg("last_page_len", last_page_len)

        # check for updated slot_idx
@@ -1049,6 +1004,90 @@ class SequenceInfo:
        host_function(**{arg: self._get_arg(arg) for arg in args})


class ResourceHandler(ABC):
    """An abstract interface to handle a generic resource needed by attention operators.

    The ResourceHandler interface standardizes operations that the cached sequence interface
    performs on the resources providing an abstract handle.
    """

    @abstractmethod
    def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
        """Initialize the resource for the given sequence info."""


class ManagedResourceHandler(ResourceHandler):
    """An abstract interface to handle a resource that is managed by the cache manager."""

    @final
    def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
        """Allocate the resource for the given sequence info."""
        raise NotImplementedError("Managed resources should not be allocated directly!")


class PagedResourceHandler(ManagedResourceHandler):
    """An abstract interface to handle a paged resource.

    The PagedResourceHandler can be used to handle resources that support paging such as kv-caches.
    """

    def __init__(self, *token_shape: int, dtype: torch.dtype) -> None:
        """Initialize the PagedResourceHandler.

        Args:
            page_shape: The shape of a single page of the resource.
            dtype: The dtype of the resource.
        """
        self.token_shape = token_shape
        self.dtype = dtype


class StateResourceHandler(ManagedResourceHandler):
    """Handler for per-sequence state resources (e.g., Mamba SSM/conv states).

    These resources have shape [max_batch_size, *state_shape] and are
    managed by MambaHybridCacheManager via byte-level pooling.
    """

    def __init__(self, *state_shape: int, dtype: torch.dtype) -> None:
        """Initialize the StateResourceHandler.

        Args:
            state_shape: The shape of a single state resource.
            dtype: The dtype of the state resource.
        """
        self.state_shape = state_shape
        self.dtype = dtype


class UnpagedResourceHandler(ResourceHandler):
    """Handler for per-token unpaged resources (e.g., unpaged KV caches).

    These resources have shape [max_batch_size, max_seq_len, *token_shape].
    They are allocated locally and not managed by MambaHybridCacheManager.
    """

    def __init__(self, *token_shape: int, dtype: torch.dtype) -> None:
        """Initialize the UnpagedResourceHandler.

        Args:
            token_shape: The shape of the resource per token.
            dtype: The dtype of the resource.
        """
        self.token_shape = token_shape
        self.dtype = dtype

    def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
        """Initialize the unpaged resource for the given sequence info."""
        return torch.empty(
            sequence_info.max_num_state_slots,
            sequence_info.max_seq_len,
            *self.token_shape,
            device=sequence_info.device,
            dtype=self.dtype,
        )

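The concrete attention backends later in this diff return these handlers from `get_cache_initializers`; a stripped-down sketch of that pattern (shapes and dtypes invented for illustration):

```python
import torch

# Hypothetical per-op resource declaration mirroring the backends below: a paged KV-cache
# entry (per-token shape) plus a per-sequence SSM state entry.
resources = {
    "k_cache": PagedResourceHandler(8, 128, dtype=torch.bfloat16),
    "ssm_state_cache": StateResourceHandler(8, 64, 128, dtype=torch.float32),
}
# Managed handlers are allocated by the cache manager rather than directly; only
# UnpagedResourceHandler allocates its tensor itself via .allocate(sequence_info).
```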
class MHACallable(Protocol):
    def __call__(
        self,
@@ -1062,18 +1101,10 @@ class PrepareMetadataCallable(Protocol):
    ) -> List[torch.Tensor]: ...


class GetCacheCallable(Protocol):
    def __call__(self, sequence_info: SequenceInfo) -> torch.Tensor: ...


class GetBufferCallable(GetCacheCallable):
    pass


CacheInitializerDict = Dict[str, GetCacheCallable]
BufferInitializerDict = Dict[str, GetBufferCallable]
AttentionLayout = Literal["bsnd", "bnsd"]

ResourceHandlerDict = Dict[str, ResourceHandler]


class AttentionDescriptor(ABC):
    """An interface to define a functional attention operator.
@@ -1083,11 +1114,6 @@ class AttentionDescriptor(ABC):
    specific to the attention op.
    """

    @classmethod
    @abstractmethod
    def is_paged(cls) -> bool:
        """Return if the attention op is paged or not."""

    @classmethod
    @abstractmethod
    def get_attention_layout(cls) -> AttentionLayout:
@@ -1116,7 +1142,6 @@ class AttentionDescriptor(ABC):
        *meta_std,  # standard metadata fields identified by matching arg names!
        *meta_extra,  # metadata about the sequences as returned by the prepare_metadata op
        *caches,  # contains layer-specific caches per provided cache initializers
        *buffers,  # global buffers used by the attention op as provided by buffer initializers
        *constants,  # basic arguments (int, float, str, None) added as CONSTANTS in the graph
    ) -> torch.Tensor: ...
    ```
@@ -1172,52 +1197,46 @@ class AttentionDescriptor(ABC):
    @classmethod
    @abstractmethod
    def get_cache_initializers(
        cls, source_attn_node: Node, cache_config: CacheConfig
    ) -> CacheInitializerDict:
        """Provide a dictionary of function pointers that can be used to initialize the caches.
        cls, source_attn_node: Node, cache_config: KvCacheConfig
    ) -> ResourceHandlerDict:
        """Provide a dictionary of resource handlers that can be used to initialize the resources.

        The key corresponds to the argument name used in the attention op signature. The function
        key doesn't need to be unique across multiple attention nodes in the graph. The key used to
        describe the cache in the graph will be patched with the attention node index to ensure
        uniqueness.

        ``get_cache_initializers`` will be called *once* during cache initialization and before
        the initial forward pass for each attention op detected in the graph. The caches will be
        managed by the global CacheManager and passed back to the attention op during the forward
        pass.
        The resource will be initialized before the initial forward pass and will be managed by the
        global CacheManager and passed back to the model during the forward pass.

        If the cache initializer requires information about the attention op, it can retrieve
        the necessary information from the source attention node and cache config.
        """

    @classmethod
    def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
        """Provide a dictionary of function pointers that can be used to initialize buffers.

        The key corresponds to the buffer name used in the graph module and will **not**
        be patched unlike a cache key. Hence, it is a **global** key that is shared across all
        attention ops in the model much like a regular buffer in an nn.Module. That means if this
        i/f is called for multiple attention ops, the same buffer will be shared across all of them
        if this function provides the same key multiple times.

        Buffers are initialize *once* after the model initialization and before the initial forward
        pass for each attention op detected in the graph. The buffer will be managed by the global
        CacheManager and passed back to the attention op during the forward pass.

        If the buffer initializer requires information about the attention op, it can retrieve
        the necessary information from the source attention node.
        """
        return {}

    @classmethod
    def get_constants(cls, source_attn_node: Node) -> List[Constant]:
        """Provide a list of constant arguments to be passed to the attention op.

        The constant arguments are passed to the attention op as additional arguments after the
        caches and buffers. The constants are expected to be of type int, float, str, or None.
        caches. The constants are expected to be of type int, float, str, or None.
        """
        return []

    @staticmethod
    def resolve_cache_dtype(dtype_config: str, fallback_dtype: torch.dtype) -> torch.dtype:
        """Resolve cache dtype from KvCacheConfig dtype string to torch.dtype.

        Args:
            dtype_config: The dtype string from KvCacheConfig (e.g., "auto", "float16", "bfloat16").
            fallback_dtype: The fallback dtype to use when dtype_config is "auto".

        Returns:
            The resolved torch.dtype.
        """
        if dtype_config == "auto":
            return fallback_dtype
        return str_dtype_to_torch(dtype_config)

    @classmethod
    def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
        """Get function that performs host-side prep for the forward pass for the attention op.
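The behavior of the new `resolve_cache_dtype` helper can be summarized with two illustrative calls (assuming `str_dtype_to_torch` maps standard dtype strings such as "float16" to the matching torch dtype):

```python
import torch

# "auto" defers to the dtype observed on the source tensor:
assert AttentionDescriptor.resolve_cache_dtype("auto", torch.bfloat16) is torch.bfloat16
# Explicit strings are converted via str_dtype_to_torch:
assert AttentionDescriptor.resolve_cache_dtype("float16", torch.bfloat16) is torch.float16
```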
@@ -11,17 +11,16 @@ import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node

from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
    AttentionDescriptor,
    AttentionLayout,
    AttentionRegistry,
    BufferInitializerDict,
    CacheConfig,
    CacheInitializerDict,
    Constant,
    MHACallable,
    SequenceInfo,
    ResourceHandlerDict,
    StateResourceHandler,
)
from .delta_rule.chunk import chunk_delta_rule_fwd
from .delta_rule.fused_recurrent import fused_recurrent_delta_rule_fwd
@@ -136,11 +135,6 @@ def fla_cached_delta_rule_fake(

@AttentionRegistry.register("fla_delta")
class FlaDeltaBackend(AttentionDescriptor):
    @classmethod
    def is_paged(cls) -> bool:
        # TODO: we should refine our notion of "is_paged" --> seems counterintuitive for ssm nows
        return True

    @classmethod
    def get_attention_layout(cls) -> AttentionLayout:
        return "bsnd"
@@ -164,8 +158,8 @@ class FlaDeltaBackend(AttentionDescriptor):

    @classmethod
    def get_cache_initializers(
        cls, source_attn_node: Node, cache_config: CacheConfig
    ) -> CacheInitializerDict:
        cls, source_attn_node: Node, cache_config: KvCacheConfig
    ) -> ResourceHandlerDict:
        key_node = source_attn_node.args[1]
        value_node = source_attn_node.args[2]
        num_heads = key_node.meta["val"].shape[-2]
@@ -173,21 +167,15 @@ class FlaDeltaBackend(AttentionDescriptor):
        value_dim = value_node.meta["val"].shape[-1]
        key_dtype = key_node.meta["val"].dtype

        def _get_delta_cache(si: SequenceInfo):
            return torch.empty(
                si.max_state_slots,
        return {
            "delta_cache": StateResourceHandler(
                num_heads,
                key_dim,
                value_dim,
                device=si.device,
                dtype=cache_config.delta_dtype or key_dtype,
                # NOTE: not configurable at the moment, using auto to match the key dtype
                dtype=cls.resolve_cache_dtype("auto", key_dtype),
            )

        return {"delta_cache": _get_delta_cache}

    @classmethod
    def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
        return {}
        }

    @classmethod
    def get_constants(cls, source_attn_node: Node) -> List[Constant]:

@@ -1,5 +1,5 @@
from dataclasses import dataclass, fields
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Literal, Optional, Tuple, Union

import flashinfer
import torch
@@ -7,6 +7,7 @@ from torch._ops import OpOverloadPacket
from torch._subclasses import FakeTensor
from torch.fx import Node

from ....llmapi.llm_args import KvCacheConfig
from ...flashinfer_utils import get_env_enable_pdl
from ..utils.cuda_graph import cuda_graph_state
from ..utils.logger import ad_logger
@@ -15,14 +16,12 @@ from .attention_interface import (
    AttentionDescriptor,
    AttentionLayout,
    AttentionRegistry,
    BufferInitializerDict,
    CacheConfig,
    CacheInitializerDict,
    Constant,
    MHACallable,
    PagedResourceHandler,
    PrepareMetadataCallable,
    PrepareMetadataHostCallable,
    SequenceInfo,
    ResourceHandlerDict,
)


@@ -57,6 +56,9 @@ class _FlashInferPlanner:
    ]
    plan_params_prefill: Optional[PlanParams]
    plan_params_decode: Optional[PlanParams]
    kv_layout: Literal["NHD", "HND"] = (
        "NHD"  # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/10966
    )

    def __init__(self):
        self.workspace_buffer = None
@@ -77,7 +79,7 @@ class _FlashInferPlanner:
        if use_cuda_graph:
            return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                self.kv_layout,
                use_cuda_graph=True,
                paged_kv_indptr_buffer=indptr,
                paged_kv_indices_buffer=indices,
@@ -88,28 +90,33 @@ class _FlashInferPlanner:
        else:
            return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                self.kv_layout,
                use_tensor_cores=True,
                backend="fa2" if torch.cuda.get_device_capability(0) == (9, 0) else "auto",
            )

    def init_workspace(self, workspace_buffer: torch.Tensor):
    def reset(self, device: torch.device) -> None:
        self.plan_params_prefill = None
        self.plan_params_decode = None

        if isinstance(self.workspace_buffer, torch.Tensor):
            return

        self.__init__()  # reset all state

        self.workspace_buffer = workspace_buffer
        # NOTE (lucaslie): avoid OOM for many cudagraphs,
        # see https://github.com/NVIDIA/TensorRT-LLM/pull/3686
        self.workspace_buffer = torch.empty(320 * 1024 * 1024, device=device, dtype=torch.uint8)

        # NOTE (lucaslie): flashinfer fa3 backend has accuracy issue + illegal memory access issues
        # on H100 PCIe, see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
        self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
            self.workspace_buffer,
            "NHD",
            self.kv_layout,
            backend="fa2",
        )
        self.decode_wrapper = self._init_decode_wrapper()

    def reset(self) -> None:
        self.plan_params_prefill = None
        self.plan_params_decode = None

    def plan_generate_only(
        self,
        num_seq: int,
@@ -248,7 +255,7 @@ def prepare_flashinfer_metadata(
    num_seq = num_prefill + num_decode
    num_tokens = num_prefill_tokens + num_decode

    _GlobalFlashInferPlanner.reset()
    _GlobalFlashInferPlanner.reset(position_ids.device)

    qo_indptr = cu_seqlen[: num_seq + 1]
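With this change the planner no longer receives a workspace buffer from a global buffer initializer; `reset(device)` lazily allocates its own 320 MiB workspace on first use. A rough sketch of the resulting call pattern (names taken from the hunks above):

```python
import torch

planner = _FlashInferPlanner()
planner.reset(torch.device("cuda:0"))  # first call: allocates workspace + prefill/decode wrappers
planner.reset(torch.device("cuda:0"))  # subsequent calls: only clear the cached plan params
```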
@@ -316,8 +323,6 @@ def flashinfer_mha_with_cache(
    # CACHES
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    # BUFFERS
    workspace_buffer: torch.Tensor,
    # CONSTANTS
    scale: Optional[float],
    k_scale: float,
@@ -458,8 +463,6 @@ def flashinfer_mha_with_cache_fake(
    # CACHES
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    # BUFFERS
    workspace_buffer: torch.Tensor,
    # CONSTANTS
    scale: Optional[float],
    k_scale: float,
@@ -474,11 +477,6 @@ class FlashInferAttention(AttentionDescriptor):
    def _get_planner(cls) -> _FlashInferPlanner:
        return _GlobalFlashInferPlanner

    @classmethod
    def is_paged(cls):
        """Return if the attention op is paged or not."""
        return True

    @classmethod
    def get_attention_layout(cls) -> AttentionLayout:
        """Get the attention layout expected by the backend."""
@@ -519,35 +517,25 @@ class FlashInferAttention(AttentionDescriptor):

    @classmethod
    def get_cache_initializers(
        cls, source_attn_node: Node, cache_config: CacheConfig
    ) -> CacheInitializerDict:
        cls, source_attn_node: Node, cache_config: KvCacheConfig
    ) -> ResourceHandlerDict:
        # source op is [bsnd] layout already
        k_fake: FakeTensor = source_attn_node.args[1].meta["val"]
        num_kv_heads = k_fake.shape[2]
        head_dim = k_fake.shape[3]

        def _get_cache(si: SequenceInfo):
            return torch.empty(
                si.num_pages,
                si.page_size,
        return {
            "k_cache": PagedResourceHandler(
                num_kv_heads,
                head_dim,
                device=si.device,
                dtype=cache_config.dtype or k_fake.dtype,
            )

        return {"k_cache": _get_cache, "v_cache": _get_cache}

    @classmethod
    def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
        def _init_workspace(si: SequenceInfo) -> torch.Tensor:
            # NOTE (lucaslie): avoid OOM for many cudagraphs,
            # see https://github.com/NVIDIA/TensorRT-LLM/pull/3686
            buffer = torch.empty(320 * 1024 * 1024, dtype=torch.uint8, device=si.device)
            cls._get_planner().init_workspace(buffer)
            return buffer

        return {"workspace_buffer": _init_workspace}
                dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype),
            ),
            "v_cache": PagedResourceHandler(
                num_kv_heads,
                head_dim,
                dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype),
            ),
        }

    @classmethod
    def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
@@ -33,16 +33,16 @@ from torch.fx import Node
from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID
from tensorrt_llm._torch.modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update

from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
    AttentionDescriptor,
    AttentionLayout,
    AttentionRegistry,
    CacheConfig,
    CacheInitializerDict,
    Constant,
    MHACallable,
    SequenceInfo,
    ResourceHandlerDict,
    StateResourceHandler,
)


@@ -164,10 +164,6 @@ def cuda_cached_causal_conv1d_wrapper(input: torch.Tensor, *args, **kwargs) -> t

@AttentionRegistry.register("cuda_causal_conv")
class CudaBackendCausalConv(AttentionDescriptor):
    @classmethod
    def is_paged(cls) -> bool:
        return True

    @classmethod
    def get_attention_layout(cls) -> AttentionLayout:
        # Hidden states follow [b, s, c]
@@ -193,24 +189,21 @@ class CudaBackendCausalConv(AttentionDescriptor):

    @classmethod
    def get_cache_initializers(
        cls, source_attn_node: Node, cache_config: CacheConfig
    ) -> CacheInitializerDict:
        cls, source_attn_node: Node, cache_config: KvCacheConfig
    ) -> ResourceHandlerDict:
        inp_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
        w_fake: torch.Tensor = source_attn_node.args[1].meta["val"]

        in_channels = inp_fake.shape[-1]
        kernel_size = w_fake.shape[-1]

        def _get_conv_cache(si: SequenceInfo):
            return torch.empty(
                si.max_state_slots,
                in_channels,
                max(1, kernel_size - 1),
                device=si.device,
                dtype=inp_fake.dtype,
            )

        return {"conv_state_cache": _get_conv_cache}
        conv_state_handler = StateResourceHandler(
            in_channels,
            max(1, kernel_size - 1),
            # NOTE: not configurable at the moment, using auto to match the input dtype
            dtype=cls.resolve_cache_dtype("auto", inp_fake.dtype),
        )
        return {"conv_state_cache": conv_state_handler}

    @classmethod
    def get_constants(cls, source_attn_node: Node) -> List[Constant]:

@@ -16,17 +16,16 @@ import torch.nn.functional as F
from torch._ops import OpOverloadPacket
from torch.fx import Node

from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
    AttentionDescriptor,
    AttentionLayout,
    AttentionRegistry,
    BufferInitializerDict,
    CacheConfig,
    CacheInitializerDict,
    Constant,
    MHACallable,
    SequenceInfo,
    ResourceHandlerDict,
    StateResourceHandler,
)


@@ -270,11 +269,6 @@ def _torch_cached_causal_conv1d_fake(

@AttentionRegistry.register("torch_causal_conv")
class TorchBackendCausalConv(AttentionDescriptor):
    @classmethod
    def is_paged(cls) -> bool:
        # TODO: we should refine our notion of "is_paged" --> seems counterintuitive for ssm nows
        return True

    @classmethod
    def get_attention_layout(cls) -> AttentionLayout:
        # Hidden states follow [b, s, c]
@@ -300,28 +294,22 @@ class TorchBackendCausalConv(AttentionDescriptor):

    @classmethod
    def get_cache_initializers(
        cls, source_attn_node: Node, cache_config: CacheConfig
    ) -> CacheInitializerDict:
        cls, source_attn_node: Node, cache_config: KvCacheConfig
    ) -> ResourceHandlerDict:
        inp_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
        w_fake: torch.Tensor = source_attn_node.args[1].meta["val"]

        in_channels = inp_fake.shape[-1]
        kernel_size = w_fake.shape[-1]

        def _get_conv_cache(si: SequenceInfo):
            return torch.empty(
                si.max_state_slots,
        return {
            "conv_state_cache": StateResourceHandler(
                in_channels,
                kernel_size,
                device=si.device,
                dtype=inp_fake.dtype,
                # NOTE: not configurable at the moment, using auto to match the input dtype
                dtype=cls.resolve_cache_dtype("auto", inp_fake.dtype),
            )

        return {"conv_state_cache": _get_conv_cache}

    @classmethod
    def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
        return {}
        }

    @classmethod
    def get_constants(cls, source_attn_node: Node) -> List[Constant]:

@@ -12,17 +12,16 @@ import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node

from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
    AttentionDescriptor,
    AttentionLayout,
    AttentionRegistry,
    BufferInitializerDict,
    CacheConfig,
    CacheInitializerDict,
    Constant,
    MHACallable,
    SequenceInfo,
    ResourceHandlerDict,
    StateResourceHandler,
)
from .torch_mamba import _torch_ssm_prefill

@@ -268,11 +267,6 @@ def _torch_cached_ssm_fake(

@AttentionRegistry.register("torch_ssm")
class TorchBackendSSM(AttentionDescriptor):
    @classmethod
    def is_paged(cls) -> bool:
        # TODO: we should refine our notion of "is_paged" --> seems counterintuitive for ssm now
        return True

    @classmethod
    def get_attention_layout(cls) -> AttentionLayout:
        # Hidden states follow [b, s, n, d]
@@ -297,8 +291,8 @@ class TorchBackendSSM(AttentionDescriptor):

    @classmethod
    def get_cache_initializers(
        cls, source_attn_node: Node, cache_config: CacheConfig
    ) -> CacheInitializerDict:
        cls, source_attn_node: Node, cache_config: KvCacheConfig
    ) -> ResourceHandlerDict:
        # Shapes from fake tensors
        hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
        B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]
@@ -315,23 +309,13 @@ class TorchBackendSSM(AttentionDescriptor):
        ssm_state_size = max(1, B_fake.shape[-1])

        # extract ssm_state_dtype from cache_config or hs_fake
        ssm_state_dtype = cache_config.mamba_dtype or hs_fake.dtype
        ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype)

        def _get_ssm_cache(si: SequenceInfo):
            return torch.empty(
                si.max_state_slots,
                num_heads,
                head_dim,
                ssm_state_size,
                device=si.device,
                dtype=ssm_state_dtype,
        return {
            "ssm_state_cache": StateResourceHandler(
                num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype
            )

        return {"ssm_state_cache": _get_ssm_cache}

    @classmethod
    def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
        return {}
        }

    @classmethod
    def get_constants(cls, source_attn_node: Node) -> List[Constant]:

@@ -24,17 +24,17 @@ from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chun
from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update
from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined

from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
    AttentionDescriptor,
    AttentionLayout,
    AttentionRegistry,
    CacheConfig,
    CacheInitializerDict,
    Constant,
    MHACallable,
    PrepareMetadataCallable,
    SequenceInfo,
    ResourceHandlerDict,
    StateResourceHandler,
)


@@ -274,10 +274,6 @@ def _triton_cached_ssm_fake(

@AttentionRegistry.register("triton_ssm")
class TritonBackendSSM(AttentionDescriptor):
    @classmethod
    def is_paged(cls) -> bool:
        return True

    @classmethod
    def get_attention_layout(cls) -> AttentionLayout:
        # Hidden states follow [b, s, n, d]
@@ -313,8 +309,8 @@ class TritonBackendSSM(AttentionDescriptor):

    @classmethod
    def get_cache_initializers(
        cls, source_attn_node: Node, cache_config: CacheConfig
    ) -> CacheInitializerDict:
        cls, source_attn_node: Node, cache_config: KvCacheConfig
    ) -> ResourceHandlerDict:
        # Shapes from fake tensors
        hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
        B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]
@@ -328,19 +324,13 @@ class TritonBackendSSM(AttentionDescriptor):
        ssm_state_size = max(1, B_fake.shape[-1])

        # extract ssm_state_dtype from cache_config or hs_fake
        ssm_state_dtype = cache_config.mamba_dtype or hs_fake.dtype
        ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype)

        def _get_ssm_cache(si: SequenceInfo):
            return torch.empty(
                si.max_state_slots,
                num_heads,
                head_dim,
                ssm_state_size,
                device=si.device,
                dtype=ssm_state_dtype,
        return {
            "ssm_state_cache": StateResourceHandler(
                num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype
            )

        return {"ssm_state_cache": _get_ssm_cache}
        }

    @classmethod
    def get_constants(cls, source_attn_node: Node) -> List[Constant]:

@@ -6,21 +6,35 @@ import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node

from ....llmapi.llm_args import KvCacheConfig
from .attention_interface import (
    AttentionDescriptor,
    AttentionLayout,
    AttentionRegistry,
    BufferInitializerDict,
    CacheConfig,
    CacheInitializerDict,
    MHACallable,
    SequenceInfo,
    ResourceHandlerDict,
    UnpagedResourceHandler,
)
from .triton_attention import _flattened_context_mha, _generate_mha

Constant = Union[int, float, str, None]


def _precompute_inv_freq(
    max_seq_len: int, head_dim: int, rope_theta: float, device: torch.device
) -> torch.Tensor:
    inv_freq = 1.0 / (
        rope_theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)
    )
    t = torch.arange(max_seq_len, device=inv_freq.device, dtype=inv_freq.dtype)

    freqs = torch.outer(t, inv_freq.to(t.device))
    # Different from paper, but it uses a different permutation in order to obtain the same calculation
    emb = torch.cat((freqs, freqs), dim=-1)
    cos_sin_stacked = torch.stack([emb.cos().to(torch.bfloat16), emb.sin().to(torch.bfloat16)])
    return cos_sin_stacked


@torch.library.custom_op(
    "auto_deploy::triton_attention_fused_flattened_mla_with_cache", mutates_args=()
)
@@ -41,10 +55,9 @@ def fused_flattened_mla_with_cache(
    # CACHES
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    # BUFFERS
    cos_sin_stacked: torch.Tensor,
    # CONSTANTS
    softmax_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
) -> torch.Tensor:
    """Flattened & fused MLA with cache with triton kernels."""
    # b, s info
@@ -84,7 +97,12 @@ def fused_flattened_mla_with_cache(
    k_pe = k_pe.clone().transpose(1, 2).view(*bs_view, -1, qk_rope_head_dim).contiguous()
    value_states = value_states.transpose(1, 2).view(*bs_view, -1, v_head_dim).contiguous()
    # Apply RoPE
    if cos_sin_stacked.numel() > 0:
    if rope_theta is not None:
        max_seq_len = (input_pos + seq_len).max().item()
        cos_sin_stacked = _precompute_inv_freq(
            max_seq_len, qk_rope_head_dim, rope_theta, q_pe.device
        )

        # Extract cos and sin from freqs_cis
        cos_base = cos_sin_stacked[0, ...]
        sin_base = cos_sin_stacked[1, ...]
@@ -176,10 +194,9 @@ def fused_flattened_mla_with_cache_fake(
    # CACHES
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    # BUFFERS
    cos_sin_stacked: torch.Tensor,
    # CONSTANTS
    softmax_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
):
    v_head_dim = kv.shape[-1] - q_nope.shape[-1]
    return torch.empty_like(kv[..., -v_head_dim:])
@@ -187,11 +204,6 @@ def fused_flattened_mla_with_cache_fake(

@AttentionRegistry.register("MultiHeadLatentAttention")
class MultiHeadLatentAttention(AttentionDescriptor):
    @classmethod
    def is_paged(cls) -> bool:
        """Return if the attention op is paged or not."""
        return False

    @classmethod
    def get_attention_layout(cls) -> AttentionLayout:
        """Get the attention layout expected by the backend."""
@@ -216,8 +228,8 @@ class MultiHeadLatentAttention(AttentionDescriptor):

    @classmethod
    def get_cache_initializers(
        cls, source_attn_node: Node, cache_config: CacheConfig
    ) -> CacheInitializerDict:
        cls, source_attn_node: Node, cache_config: KvCacheConfig
    ) -> ResourceHandlerDict:
        q_nope_fake = source_attn_node.args[0].meta["val"]
        q_pe_fake = source_attn_node.args[1].meta["val"]
        kv_fake = source_attn_node.args[2].meta["val"]
@@ -226,56 +238,21 @@ class MultiHeadLatentAttention(AttentionDescriptor):
        head_dim = q_nope_fake.shape[-1]
        rope_dim = q_pe_fake.shape[-1]

        def _get_k_cache(si: SequenceInfo):
            assert not si.is_paged, "Paged cache not supported for MultiHeadLatentAttention"
            return torch.empty(
                si.num_pages,
                si.page_size,
        return {
            "k_cache": UnpagedResourceHandler(
                num_kv_heads,
                head_dim + rope_dim,
                device=si.device,
                dtype=cache_config.dtype or kv_fake.dtype,
            )

        def _get_v_cache(si: SequenceInfo):
            assert not si.is_paged, "Paged cache not supported for MultiHeadLatentAttention"
            return torch.empty(
                si.num_pages,
                si.page_size,
                dtype=cls.resolve_cache_dtype(cache_config.dtype, kv_fake.dtype),
            ),
            "v_cache": UnpagedResourceHandler(
                num_kv_heads,
                head_dim,
                device=si.device,
                dtype=cache_config.dtype or kv_fake.dtype,
            )

        return {"k_cache": _get_k_cache, "v_cache": _get_v_cache}

    @classmethod
    def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
        q_pe_fake = source_attn_node.args[1].meta["val"]
|
||||
rope_head_dim = q_pe_fake.shape[-1]
|
||||
rope_theta: float = 10000.0 # TODO: remove once MLA is unfused
|
||||
|
||||
def _get_cos_sin_stacked(si: SequenceInfo):
|
||||
if rope_theta is None:
|
||||
return torch.empty(0, device=si.device)
|
||||
return cls._precompute_inv_freq(si.max_seq_len, rope_head_dim, rope_theta).to(si.device)
|
||||
|
||||
return {
|
||||
f"cos_sin_stacked_{rope_head_dim}_{rope_theta}".replace(".", "_"): _get_cos_sin_stacked
|
||||
dtype=cls.resolve_cache_dtype(cache_config.dtype, kv_fake.dtype),
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
|
||||
return [None]
|
||||
|
||||
@staticmethod
|
||||
def _precompute_inv_freq(seq_len: int, head_dim: int, rope_theta: float = 1e4) -> torch.Tensor:
|
||||
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
||||
t = torch.arange(seq_len, device=inv_freq.device, dtype=inv_freq.dtype)
|
||||
|
||||
freqs = torch.outer(t, inv_freq.to(t.device))
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos_sin_stacked = torch.stack([emb.cos().to(torch.bfloat16), emb.sin().to(torch.bfloat16)])
|
||||
return cos_sin_stacked
|
||||
softmax_scale = None
|
||||
rope_theta = 10000.0 # TODO: remove once MLA is unfused
|
||||
return [softmax_scale, rope_theta]
|
||||
|
||||
@ -8,18 +8,17 @@ from torch._ops import OpOverloadPacket
|
||||
from torch._subclasses import FakeTensor
|
||||
from torch.fx import Node
|
||||
|
||||
from ....llmapi.llm_args import KvCacheConfig
|
||||
from ..utils.logger import ad_logger
|
||||
from ..utils.node_utils import extract_op_args
|
||||
from .attention_interface import (
|
||||
AttentionDescriptor,
|
||||
AttentionLayout,
|
||||
AttentionRegistry,
|
||||
BufferInitializerDict,
|
||||
CacheConfig,
|
||||
CacheInitializerDict,
|
||||
Constant,
|
||||
MHACallable,
|
||||
SequenceInfo,
|
||||
ResourceHandlerDict,
|
||||
UnpagedResourceHandler,
|
||||
)
|
||||
from .torch_attention import repeat_kv, update_kv_cache
|
||||
|
||||
@ -375,11 +374,6 @@ def torch_backend_mha_with_cache_fake(
|
||||
|
||||
@AttentionRegistry.register("torch")
|
||||
class TorchBackendAttention(AttentionDescriptor):
|
||||
@classmethod
|
||||
def is_paged(cls) -> bool:
|
||||
"""Return if the attention op is paged or not."""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_attention_layout(cls) -> AttentionLayout:
|
||||
"""Get the attention layout expected by the source op and the cached attention op."""
|
||||
@ -404,8 +398,8 @@ class TorchBackendAttention(AttentionDescriptor):
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
cls, source_attn_node: Node, cache_config: CacheConfig
|
||||
) -> CacheInitializerDict:
|
||||
cls, source_attn_node: Node, cache_config: KvCacheConfig
|
||||
) -> ResourceHandlerDict:
|
||||
# source op is [bsnd] layout already
|
||||
k_fake: FakeTensor = source_attn_node.args[1].meta["val"]
|
||||
v_fake: FakeTensor = source_attn_node.args[2].meta["val"]
|
||||
@ -413,33 +407,18 @@ class TorchBackendAttention(AttentionDescriptor):
|
||||
k_head_dim = k_fake.shape[3]
|
||||
v_head_dim = v_fake.shape[3]
|
||||
|
||||
def _get_k_cache(si: SequenceInfo):
|
||||
assert not si.is_paged, "Paged cache not supported for torch backend"
|
||||
return torch.empty(
|
||||
si.num_pages,
|
||||
si.page_size,
|
||||
return {
|
||||
"k_cache": UnpagedResourceHandler(
|
||||
num_kv_heads,
|
||||
k_head_dim,
|
||||
device=si.device,
|
||||
dtype=cache_config.dtype or k_fake.dtype,
|
||||
)
|
||||
|
||||
def _get_v_cache(si: SequenceInfo):
|
||||
assert not si.is_paged, "Paged cache not supported for torch backend"
|
||||
return torch.empty(
|
||||
si.num_pages,
|
||||
si.page_size,
|
||||
dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype),
|
||||
),
|
||||
"v_cache": UnpagedResourceHandler(
|
||||
num_kv_heads,
|
||||
v_head_dim,
|
||||
device=si.device,
|
||||
dtype=cache_config.dtype or v_fake.dtype,
|
||||
)
|
||||
|
||||
return {"k_cache": _get_k_cache, "v_cache": _get_v_cache}
|
||||
|
||||
@classmethod
|
||||
def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
|
||||
return {}
|
||||
dtype=cls.resolve_cache_dtype(cache_config.dtype, v_fake.dtype),
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
|
||||
|
||||
@ -172,7 +172,7 @@ def torch_fake_quant_fp8_linear(
    - input_zp / weight_zp ignored
    """
    if weight_quantized.dtype != torch.float8_e4m3fn:
        raise TypeError("FP8 path requires weight_quantized.dtype == float8_e4m3fn")
        raise TypeError("FP8 path requires weight_quantized.dtype == torch.float8_e4m3fn")
    s_in = _expect_single_scale(input_scale, "input_scale")
    s_w = _expect_single_scale(weight_scale, "weight_scale")

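# NOTE: Illustrative sketch (not part of this diff), and an assumption about what this
# fake-quant FP8 reference path computes: dequantize with the single per-tensor scales and
# fall back to a regular linear. See the actual implementation above for the exact semantics.
#
#     x_dq = (x / s_in).to(torch.float8_e4m3fn).to(torch.float32) * s_in   # fake-quantize input
#     w_dq = weight_quantized.to(torch.float32) * s_w                      # dequantize fp8 weight
#     y = torch.nn.functional.linear(x_dq, w_dq, bias)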
@ -9,18 +9,17 @@ from torch._ops import OpOverloadPacket
|
||||
from torch._subclasses import FakeTensor
|
||||
from torch.fx import Node
|
||||
|
||||
from ....llmapi.llm_args import KvCacheConfig
|
||||
from ..utils.logger import ad_logger
|
||||
from ..utils.node_utils import extract_op_args
|
||||
from .attention_interface import (
|
||||
AttentionDescriptor,
|
||||
AttentionLayout,
|
||||
AttentionRegistry,
|
||||
BufferInitializerDict,
|
||||
CacheConfig,
|
||||
CacheInitializerDict,
|
||||
Constant,
|
||||
MHACallable,
|
||||
SequenceInfo,
|
||||
ResourceHandlerDict,
|
||||
UnpagedResourceHandler,
|
||||
)
|
||||
from .triton_kernels.attention_with_kv_cache import (
|
||||
attention_kv_stage2,
|
||||
@ -312,11 +311,6 @@ def flattened_mha_fake(
|
||||
|
||||
@AttentionRegistry.register("triton")
|
||||
class TritonAttention(AttentionDescriptor):
|
||||
@classmethod
|
||||
def is_paged(cls) -> bool:
|
||||
"""Return if the attention op is paged or not."""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_attention_layout(cls) -> AttentionLayout:
|
||||
"""Get the attention layout expected by the source op and the cached attention op."""
|
||||
@ -341,8 +335,8 @@ class TritonAttention(AttentionDescriptor):
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
cls, source_attn_node: Node, cache_config: CacheConfig
|
||||
) -> CacheInitializerDict:
|
||||
cls, source_attn_node: Node, cache_config: KvCacheConfig
|
||||
) -> ResourceHandlerDict:
|
||||
# source op is [bsnd] layout already
|
||||
k_fake: FakeTensor = source_attn_node.args[1].meta["val"]
|
||||
v_fake: FakeTensor = source_attn_node.args[2].meta["val"]
|
||||
@ -350,33 +344,18 @@ class TritonAttention(AttentionDescriptor):
|
||||
k_head_dim = k_fake.shape[3]
|
||||
v_head_dim = v_fake.shape[3]
|
||||
|
||||
def _get_k_cache(si: SequenceInfo):
|
||||
assert not si.is_paged, "Paged cache not supported for triton"
|
||||
return torch.empty(
|
||||
si.num_pages,
|
||||
si.page_size,
|
||||
return {
|
||||
"k_cache": UnpagedResourceHandler(
|
||||
num_kv_heads,
|
||||
k_head_dim,
|
||||
device=si.device,
|
||||
dtype=cache_config.dtype or k_fake.dtype,
|
||||
)
|
||||
|
||||
def _get_v_cache(si: SequenceInfo):
|
||||
assert not si.is_paged, "Paged cache not supported for triton"
|
||||
return torch.empty(
|
||||
si.num_pages,
|
||||
si.page_size,
|
||||
dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype),
|
||||
),
|
||||
"v_cache": UnpagedResourceHandler(
|
||||
num_kv_heads,
|
||||
v_head_dim,
|
||||
device=si.device,
|
||||
dtype=cache_config.dtype or v_fake.dtype,
|
||||
)
|
||||
|
||||
return {"k_cache": _get_k_cache, "v_cache": _get_v_cache}
|
||||
|
||||
@classmethod
|
||||
def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
|
||||
return {}
|
||||
dtype=cls.resolve_cache_dtype(cache_config.dtype, v_fake.dtype),
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
|
||||
|
||||
@ -9,7 +9,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from ...llmapi.llm_args import (
    BuildConfig,
    EagleDecodingConfig,
    KvCacheConfig,
    SamplerType,
    TorchLlmArgs,
    _ParallelConfig,
@ -45,7 +44,6 @@ def _check_for_default_value_only(

_TRANSFORMS_SHORTCUT_LOOKUP = {
    "attn_backend": ("insert_cached_attention.backend", "transformers_replace_cached_attn.backend"),
    "free_mem_ratio": ("resize_kv_cache.free_mem_ratio",),
    "compile_backend": ("compile_model.backend",),
    "cuda_graph_batch_sizes": ("compile_model.cuda_graph_batch_sizes",),
}
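# NOTE: Illustrative sketch (not part of this diff). The lookup above maps flat shortcut
# fields on LlmArgs to dotted keys inside the nested `transforms` config, e.g. setting
# `compile_backend` corresponds to `transforms["compile_model"]["backend"]`. A minimal
# expansion helper under that assumption could look like:
def _apply_shortcut(transforms: dict, dotted_key: str, value) -> None:
    """Set a nested transforms field addressed by a dotted shortcut key (hypothetical helper)."""
    transform_name, field = dotted_key.split(".", 1)
    transforms.setdefault(transform_name, {})[field] = value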
@ -191,25 +189,11 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
|
||||
|
||||
device: str = Field(default="cuda", description="The device to use for the model.", frozen=True)
|
||||
|
||||
# TODO: see if we can just remove this field and use kv_cache_config.dtype instead?
|
||||
kv_cache_dtype: str = Field(
|
||||
default="auto",
|
||||
description="Data type for KV cache. This is a temporary field until kv_cache_dtype is "
|
||||
"supported in AutoDeploy.",
|
||||
)
|
||||
|
||||
sampler_type: Union[str, SamplerType] = Field(
|
||||
default=SamplerType.TorchSampler,
|
||||
description="The type of sampler to use. Options are TRTLLMSampler or TorchSampler. Defaults to TorchSampler.",
|
||||
)
|
||||
|
||||
# NOTE: we do not support copy_on_partial_reuse in AutoDeploy yet
|
||||
# see https://github.com/NVIDIA/TensorRT-LLM/issues/7142
|
||||
kv_cache_config: KvCacheConfig = Field(
|
||||
default_factory=lambda **kwargs: KvCacheConfig(copy_on_partial_reuse=False, **kwargs),
|
||||
description="KV cache config.",
|
||||
)
|
||||
|
||||
max_beam_width: int = Field(
|
||||
default=1,
|
||||
description="The maximum beam width. >1 is not supported by AutoDeploy.",
|
||||
@ -240,12 +224,6 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
|
||||
default="flashinfer",
|
||||
description=_shortcut_description("Attention backend to use.", "attn_backend"),
|
||||
)
|
||||
free_mem_ratio: float = Field(
|
||||
default=0.0,
|
||||
description=_shortcut_description(
|
||||
"The fraction of available memory to allocate for cache.", "free_mem_ratio"
|
||||
),
|
||||
)
|
||||
compile_backend: str = Field(
|
||||
default="torch-compile",
|
||||
description=_shortcut_description(
|
||||
@ -263,16 +241,8 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
|
||||
)
|
||||
|
||||
### SEQUENCE INTERFACE CONFIG ##################################################################
|
||||
max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.")
|
||||
max_seq_len: int = Field(default=512, ge=1, description="The maximum sequence length.")
|
||||
max_batch_size: int = Field(default=8, ge=1, description="The maximum batch size.")
|
||||
attn_page_size: int = Field(
|
||||
default=64,
|
||||
ge=1,
|
||||
description="Page size for attention (tokens_per_block). For triton and torch "
|
||||
"backends, this should equal max_seq_len. Temporary field until tokens_per_block gets "
|
||||
"properly passed through.",
|
||||
)
|
||||
|
||||
def model_dump(self, *args, **kwargs):
|
||||
"""Convert the arguments to a dictionary that can be used as kwargs for the LLM API."""
|
||||
@ -292,23 +262,6 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
|
||||
return kwargs
|
||||
|
||||
### VALIDATION #################################################################################
|
||||
@model_validator(mode="after")
|
||||
# TODO: discuss what to do with this once we fully transition to the new inference optimizer
|
||||
def update_attn_page_size(self):
|
||||
# NOTE force attn_page_size to equal max_seq_len for triton backend
|
||||
if self.transforms.get("insert_cached_attention", {}).get("backend") in [
|
||||
"triton",
|
||||
"torch",
|
||||
]:
|
||||
self.attn_page_size = self.max_seq_len
|
||||
# NOTE: (hg) For transformers mode. This is ugly.
|
||||
if self.transforms.get("transformers_replace_cached_attn", {}).get("backend") in [
|
||||
"triton",
|
||||
"torch",
|
||||
]:
|
||||
self.attn_page_size = self.max_seq_len
|
||||
return self
|
||||
|
||||
@field_validator("model_factory", mode="after")
|
||||
@classmethod
|
||||
def model_factory_exists(cls, value: str) -> str:
|
||||
@ -358,16 +311,6 @@ class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
|
||||
self.update_transforms_with_shortcuts()
|
||||
return self
|
||||
|
||||
@field_validator("kv_cache_config", mode="after")
|
||||
@classmethod
|
||||
def validate_kv_cache_config(cls, kv_cache_config: KvCacheConfig) -> KvCacheConfig:
|
||||
if kv_cache_config.copy_on_partial_reuse:
|
||||
kv_cache_config.copy_on_partial_reuse = False
|
||||
ad_logger.warning(
|
||||
"copy_on_partial_reuse is not supported by AutoDeploy. Setting it to False."
|
||||
)
|
||||
return kv_cache_config
|
||||
|
||||
### UTILITY METHODS ############################################################################
|
||||
def create_factory(self) -> ModelFactory:
|
||||
"""Create a model factory from the arguments."""
|
||||
|
||||
@ -27,10 +27,6 @@ from torch._prims_common import DeviceLikeType
from torch.export import Dim
from torch.fx import GraphModule

from ..custom_ops.attention_interface import CacheConfig
from ..utils.cuda_mem_tracker import get_mem_info_in_mb
from ..utils.logger import ad_logger

DynamicShape = Dict[int, Dim]  # indicating the dynamic shape in tensor dimension


@ -211,13 +207,15 @@ class ModelFactory(ABC):
        """Returns the sharding config for this model."""
        return self._sharding_config

    def get_cache_config(self) -> CacheConfig:
        """Return the cache configuration for the model.
    def get_cache_config_updates(self) -> Dict[str, Any]:
        """Return updates for the KVCacheConfig for the model.

        Returns:
            The cache configuration for the model.
            A dictionary of updates for the KVCacheConfig for the model.

        Check tensorrt_llm/llmapi/llm_args.py for the KVCacheConfig fields.
        """
        return CacheConfig()
        return {}

    def init_tokenizer(self) -> Optional[Any]:
        """Initialize the tokenizer for the model.
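# NOTE: Illustrative sketch (not part of this diff). A concrete factory can override the new
# hook to request KvCacheConfig tweaks; keys should be valid KvCacheConfig field names (see
# tensorrt_llm/llmapi/llm_args.py). The subclass below is hypothetical and omits the other
# abstract methods of ModelFactory.
class _ExampleQuantizedFactory(ModelFactory):
    def get_cache_config_updates(self) -> Dict[str, Any]:
        # e.g. ask for an fp8 KV cache, mirroring AutoModelForCausalLMFactory further below
        return {"dtype": "fp8"}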
@ -301,22 +299,12 @@ class ModelFactory(ABC):
|
||||
<SIZE_OF_LARGEST_CHECKPOINT_FILE>
|
||||
|
||||
"""
|
||||
ad_logger.info("Loading and initializing weights.")
|
||||
free_mem_pre, _ = get_mem_info_in_mb()
|
||||
ad_logger.info(f"Free memory before loading weights (MB): {free_mem_pre}")
|
||||
self._to_maybe_random(model, device)
|
||||
params_size = sum(p.numel() * p.element_size() for p in model.parameters())
|
||||
total_size_GB = params_size / (1024**3)
|
||||
ad_logger.info(f"Estimated parameters memory: {total_size_GB:.2f} GB")
|
||||
|
||||
if not self.skip_loading_weights:
|
||||
self.prefetch_checkpoint(force=True)
|
||||
self._load_checkpoint(model, device)
|
||||
|
||||
ad_logger.info("Loading and initializing weights. Done.")
|
||||
free_mem_post, _ = get_mem_info_in_mb()
|
||||
ad_logger.info(f"Free memory after loading weights (MB): {free_mem_post}")
|
||||
|
||||
@staticmethod
|
||||
def _to_maybe_random(model: nn.Module, device: DeviceLikeType):
|
||||
"""A mix of ``model.to(device)`` and random initialization of parameters.
|
||||
|
||||
@ -33,7 +33,6 @@ from transformers.utils (
    WEIGHTS_NAME,
)

from ..custom_ops.attention_interface import CacheConfig
from ..utils._config import deep_merge_dicts
from ..utils.logger import ad_logger
from .factory import (
@ -261,18 +260,16 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
            return self._quant_config_reader.get_config()
        return {}

    def get_cache_config(self):
        """Return kv cache dtype configuration."""
    def get_cache_config_updates(self):
        """Return kv cache dtype updates."""
        if not self._quant_config_reader:
            return CacheConfig(dtype=None)
            return {}

        kv_cache_dtype = self._quant_config_reader.get_config().get("kv_cache_dtype")
        torch_dtype = torch.float8_e4m3fn if kv_cache_dtype == "float8_e4m3fn" else None
        assert torch_dtype in (torch.float8_e4m3fn, None), (
            f"Unsupported dtype: {torch_dtype}. Only torch.float8_e4m3fn is supported."
        kv_cache_dtype = self._quant_config_reader.get_config().get("kv_cache_dtype", "auto")
        assert kv_cache_dtype in ("fp8", "auto"), (
            f"Unsupported dtype: {kv_cache_dtype}. Only fp8 and auto are supported."
        )

        return CacheConfig(dtype=torch_dtype)
        return {"dtype": kv_cache_dtype}

    def init_tokenizer(self) -> Optional[Any]:
        """Initialize the tokenizer—either a custom name or the model's default."""

@ -106,7 +106,7 @@ class ModelOPTQuantConfigReader(QuantConfigReader):
        if kv_algo:
            if kv_algo != "FP8":
                raise ValueError(f"KV cache quantization format {kv_algo} not supported.")
            quant_config["kv_cache_dtype"] = "float8_e4m3fn"
            quant_config["kv_cache_dtype"] = "fp8"

        self._quant_config = quant_config

@ -41,6 +41,7 @@ from tensorrt_llm._utils import nvtx_range
|
||||
from tensorrt_llm.llmapi.llm_args import (
|
||||
ContextChunkingPolicy,
|
||||
EagleDecodingConfig,
|
||||
KvCacheConfig,
|
||||
LoadFormat,
|
||||
SamplerType,
|
||||
TorchLlmArgs,
|
||||
@ -48,9 +49,9 @@ from tensorrt_llm.llmapi.llm_args import (
|
||||
from tensorrt_llm.llmapi.tokenizer import TokenizerBase
|
||||
|
||||
from ...._utils import get_free_port, mpi_rank, mpi_world_size
|
||||
from ....bindings.internal.batch_manager import CacheType
|
||||
from ....mapping import Mapping
|
||||
from ...distributed import Distributed
|
||||
from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager
|
||||
from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine
|
||||
from ...pyexecutor.py_executor import PyExecutor
|
||||
from ...pyexecutor.resource_manager import (
|
||||
@ -67,7 +68,6 @@ from ...pyexecutor.scheduler import (
|
||||
ScheduledRequests,
|
||||
SimpleScheduler,
|
||||
)
|
||||
from ..custom_ops.attention_interface import SequenceInfo
|
||||
from ..distributed.common import initialize_or_skip
|
||||
from ..llm_args import LlmArgs
|
||||
from ..transform.optimizer import InferenceOptimizer
|
||||
@ -82,44 +82,6 @@ class ReportingInfo:
|
||||
enable_iter_req_stats: bool = False
|
||||
|
||||
|
||||
class _CacheManagerWithFakePool(KVCacheManager):
|
||||
"""We use the default KVCacheManager but with a fake pool by setting head_dim=0.
|
||||
|
||||
The actual cache pools are managed by auto_deploy layerwise cache pools.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_config,
|
||||
num_blocks: int,
|
||||
tokens_per_block: int,
|
||||
max_seq_len: int,
|
||||
max_batch_size: int,
|
||||
):
|
||||
self.num_blocks = num_blocks
|
||||
super().__init__(
|
||||
kv_cache_config=kv_cache_config,
|
||||
kv_cache_type=CacheType.SELF,
|
||||
num_layers=1,
|
||||
num_kv_heads=1,
|
||||
head_dim=0,
|
||||
tokens_per_block=tokens_per_block,
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
mapping=Mapping(),
|
||||
)
|
||||
|
||||
def calculate_max_num_blocks(
|
||||
self, kv_cache_config, head_dim, tokens_per_block, mapping, dtype, kv_factor
|
||||
) -> Tuple[int, int]:
|
||||
"""Calculate the maximum number of blocks needed for the cache."""
|
||||
# TODO: this is VERY hacky... Ideally, we want to compute the number of blocks
|
||||
# just like in the original implementation. However, let's wait for the layer-wise attention
|
||||
# implementation before over-optimizing the function here
|
||||
ad_logger.info("Using fake cache manager with head_dim=0 and num pages:", self.num_blocks)
|
||||
return self.num_blocks, 0
|
||||
|
||||
|
||||
class ADHiddenStateManager(Eagle3ResourceManager):
|
||||
def __init__(
|
||||
self,
|
||||
@ -275,6 +237,7 @@ def construct_draft_llm_args(
|
||||
def create_draft_kv_cache_manager_maybe(
|
||||
draft_model_engine: Optional[PyTorchModelEngine],
|
||||
ad_config: LlmArgs,
|
||||
kv_cache_config_tuned: KvCacheConfig,
|
||||
dist_mapping: Mapping,
|
||||
) -> Optional[KVCacheManager]:
|
||||
if draft_model_engine is None or not draft_model_engine.model.model_config.is_generation:
|
||||
@ -287,8 +250,8 @@ def create_draft_kv_cache_manager_maybe(
|
||||
model_engine=draft_model_engine,
|
||||
kv_cache_manager_cls=kv_cache_manager_cls,
|
||||
mapping=dist_mapping,
|
||||
kv_cache_config=ad_config.kv_cache_config,
|
||||
tokens_per_block=ad_config.attn_page_size,
|
||||
kv_cache_config=kv_cache_config_tuned,
|
||||
tokens_per_block=kv_cache_config_tuned.tokens_per_block,
|
||||
max_seq_len=ad_config.max_seq_len,
|
||||
max_batch_size=ad_config.max_batch_size,
|
||||
spec_config=ad_config.speculative_config,
|
||||
@ -321,19 +284,33 @@ def _generate_dummy_request(
|
||||
ResourceManagerType.SPEC_RESOURCE_MANAGER
|
||||
)
|
||||
|
||||
# check if we have a free slot available and free page available
|
||||
if not slot_manager.slot_manager.free_slots or kv_cache_manager.get_num_free_blocks() == 0:
|
||||
# check if it's a hybrid kv-cache manager
|
||||
is_hybrid_cache = isinstance(kv_cache_manager, MambaHybridCacheManager)
|
||||
|
||||
# check if we have a free page and free state available
|
||||
if not kv_cache_manager.get_num_free_blocks():
|
||||
return None
|
||||
if is_hybrid_cache and not kv_cache_manager.mamba_cache_free_blocks:
|
||||
return None
|
||||
|
||||
# generate a dummy request
|
||||
dummy_request = kv_cache_manager.add_dummy_requests([request_id], **request_kwargs)[0]
|
||||
dummy_request.is_cuda_graph_dummy = True
|
||||
|
||||
# generate a dummy scheduled requests object
|
||||
dummy_scheduled_requests = ScheduledRequests()
|
||||
dummy_scheduled_requests.generation_requests.append(dummy_request)
|
||||
|
||||
# if it's a hybrid kv-cache manager, we need to manually call prepare_resources again (not done
|
||||
# in add_dummy_requests)
|
||||
if is_hybrid_cache:
|
||||
kv_cache_manager.prepare_resources(dummy_scheduled_requests)
|
||||
|
||||
# add to spec resource manager
|
||||
if spec_res_mgr:
|
||||
spec_res_mgr.add_dummy_requests([request_id])
|
||||
|
||||
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9883 clean up this hack
|
||||
# NOTE: hack to avoid blocking a slot for the dummy request
|
||||
dummy_request.seq_slot = slot_manager.get_max_resource_count()
|
||||
dummy_request.py_seq_slot = dummy_request.seq_slot
|
||||
|
||||
@ -448,11 +425,6 @@ class ADEngine(ModelEngine):
|
||||
):
|
||||
"""Build the ADEngine using the LlmArgs that gets passed through from the LLM."""
|
||||
|
||||
max_batch_size = ad_config.max_batch_size
|
||||
max_seq_len = ad_config.max_seq_len
|
||||
attn_page_size = ad_config.attn_page_size
|
||||
max_num_tokens = ad_config.max_num_tokens
|
||||
|
||||
# update device to contain the current default device if it's in cuda
|
||||
device = torch.device(ad_config.device)
|
||||
if device.type == "cuda" and device.index is None:
|
||||
@ -461,14 +433,17 @@ class ADEngine(ModelEngine):
|
||||
|
||||
factory = ad_config.create_factory()
|
||||
|
||||
# initialize seq info object
|
||||
seq_info = SequenceInfo(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
page_size=attn_page_size,
|
||||
max_num_tokens=max_num_tokens,
|
||||
# Initialize CachedSequenceInterface - it will create SequenceInfo internally
|
||||
# using tokens_per_block from kv_cache_config
|
||||
cache_seq_interface = CachedSequenceInterface(
|
||||
max_seq_len=ad_config.max_seq_len,
|
||||
max_batch_size=ad_config.max_batch_size,
|
||||
device=device,
|
||||
kv_cache_config=ad_config.kv_cache_config,
|
||||
max_num_tokens=ad_config.max_num_tokens,
|
||||
vocab_size_padded=factory.vocab_size_padded,
|
||||
)
|
||||
|
||||
reporting_info = ReportingInfo(
|
||||
print_log=False,
|
||||
enable_iter_perf_stats=ad_config.enable_iter_perf_stats,
|
||||
@ -483,8 +458,7 @@ class ADEngine(ModelEngine):
|
||||
# construct engine
|
||||
return cls(
|
||||
build_and_optimize,
|
||||
seq_info,
|
||||
device,
|
||||
cache_seq_interface,
|
||||
ad_config=ad_config,
|
||||
mapping=mapping,
|
||||
dist=dist,
|
||||
@ -495,14 +469,21 @@ class ADEngine(ModelEngine):
|
||||
def __init__(
|
||||
self,
|
||||
get_inference_model: GetInferenceModel,
|
||||
seq_info: SequenceInfo,
|
||||
device: DeviceLikeType,
|
||||
cache_seq_interface: CachedSequenceInterface,
|
||||
ad_config: Optional[LlmArgs] = None,
|
||||
mapping: Optional[Mapping] = None,
|
||||
dist: Optional[Distributed] = None,
|
||||
reporting_info: ReportingInfo = ReportingInfo(),
|
||||
) -> None:
|
||||
"""Initialize the engine with model and sequence information."""
|
||||
"""Initialize the engine with model and CachedSequenceInterface.
|
||||
|
||||
Args:
|
||||
get_inference_model: Callable that builds the inference model.
|
||||
cache_seq_interface: The CachedSequenceInterface containing sequence and cache config.
|
||||
ad_config: Optional LLM configuration.
|
||||
mapping: Optional distributed mapping configuration.
|
||||
reporting_info: Reporting configuration for logging.
|
||||
"""
|
||||
# NOTE (lucaslie): create a fake Namespace to satisfy PyExecutor requirements...
|
||||
# This is not correctly declared in the base ModelEngine class though...
|
||||
self.llm_args = SimpleNamespace()
|
||||
@ -514,8 +495,8 @@ class ADEngine(ModelEngine):
|
||||
self.llm_args.batch_wait_timeout_ms = 0
|
||||
self.llm_args.batch_wait_timeout_iters = 0
|
||||
self.llm_args.batch_wait_max_tokens_ratio = 0.0
|
||||
self.llm_args.max_num_tokens = seq_info.max_num_tokens
|
||||
self.llm_args.max_seq_len = seq_info.max_seq_len
|
||||
self.llm_args.max_num_tokens = cache_seq_interface.info.max_num_tokens
|
||||
self.llm_args.max_seq_len = cache_seq_interface.info.max_seq_len
|
||||
self.iter_counter = 0
|
||||
self.iter_states = {}
|
||||
|
||||
@ -546,13 +527,10 @@ class ADEngine(ModelEngine):
|
||||
)
|
||||
|
||||
# For compatibility with PyTorchModelEngine utilities
|
||||
self.batch_size = seq_info.max_batch_size
|
||||
self.batch_size = cache_seq_interface.info.max_batch_size
|
||||
|
||||
# construct cache sequence interface
|
||||
self.cache_seq_interface = CachedSequenceInterface(
|
||||
sequence_info=seq_info,
|
||||
device=device,
|
||||
)
|
||||
# Store the cache sequence interface
|
||||
self.cache_seq_interface = cache_seq_interface
|
||||
|
||||
# build model
|
||||
self.model = get_inference_model(self.cache_seq_interface)
|
||||
@ -628,12 +606,18 @@ class ADEngine(ModelEngine):
|
||||
# gather indices for logits
|
||||
logits_gather_indices: List[int] = []
|
||||
|
||||
page_size = self.cache_seq_interface.info.page_size
|
||||
page_size = kv_cache_manager.tokens_per_block
|
||||
dummy_token = -1
|
||||
num_ctx_requests = len(context_requests)
|
||||
num_ctx_tokens = 0
|
||||
num_generation_tokens = 0
|
||||
|
||||
# Helper to get slot index - use mamba_cache_index if available (MambaHybridCacheManager)
|
||||
def _get_slot_idx(request) -> int:
|
||||
if hasattr(kv_cache_manager, "mamba_cache_index"):
|
||||
return kv_cache_manager.mamba_cache_index[request.py_request_id]
|
||||
return request.seq_slot
|
||||
|
||||
# look at context requests first
|
||||
for request in context_requests:
|
||||
# store input ids and pos of first token in sequence
|
||||
@ -669,8 +653,8 @@ class ADEngine(ModelEngine):
|
||||
|
||||
position_ids.append(list(range(input_pos[-1], seq_len_with_cache[-1])))
|
||||
|
||||
# store seq slot idx
|
||||
slot_idx.append(request.seq_slot)
|
||||
# store seq slot idx (use mamba_cache_index if available)
|
||||
slot_idx.append(_get_slot_idx(request))
|
||||
use_initial_states.append(input_pos[-1] > 0)
|
||||
|
||||
# store extra arguments
|
||||
@ -749,7 +733,8 @@ class ADEngine(ModelEngine):
|
||||
|
||||
num_generation_tokens += 1 + get_draft_token_length(request)
|
||||
request.py_batch_idx = request.seq_slot
|
||||
slot_idx.append(request.seq_slot)
|
||||
# store seq slot idx (use mamba_cache_index if available)
|
||||
slot_idx.append(_get_slot_idx(request))
|
||||
use_initial_states.append(input_pos[-1] > 0)
|
||||
|
||||
seq_len.append(len(input_ids[-1]))
|
||||
@ -941,11 +926,11 @@ def create_draft_model_engine_maybe(
|
||||
|
||||
draft_spec_config = copy.copy(spec_config)
|
||||
|
||||
kv_cache_config = ad_config.kv_cache_config
|
||||
kv_cache_config_tuned = target_engine.cache_seq_interface.kv_cache_config_tuned
|
||||
|
||||
attn_runtime_features = AttentionRuntimeFeatures(
|
||||
chunked_prefill=ad_config.enable_chunked_prefill,
|
||||
cache_reuse=kv_cache_config.enable_block_reuse,
|
||||
cache_reuse=kv_cache_config_tuned.enable_block_reuse,
|
||||
has_speculative_draft_tokens=has_spec_drafter,
|
||||
chunk_size=target_engine.llm_args.max_num_tokens,
|
||||
)
|
||||
@ -1108,40 +1093,15 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
|
||||
else None
|
||||
)
|
||||
|
||||
# check kvcache config for partial block reuse
|
||||
# TODO: copy_on_partial_reuse is not supported yet, see
|
||||
# https://github.com/NVIDIA/TensorRT-LLM/issues/7142 for more details.
|
||||
enable_block_reuse = ad_config.kv_cache_config.enable_block_reuse
|
||||
enable_partial_reuse = ad_config.kv_cache_config.enable_partial_reuse
|
||||
copy_on_partial_reuse = ad_config.kv_cache_config.copy_on_partial_reuse
|
||||
if enable_block_reuse and enable_partial_reuse and copy_on_partial_reuse:
|
||||
raise RuntimeError(
|
||||
f"partial block reuse with {copy_on_partial_reuse=} set to True is NOT supported"
|
||||
" in AutoDeploy. Please set it to False via the kv_cache_config.copy_on_partial_reuse "
|
||||
"field in tensorrt_llm._torch.auto_deploy.llm_args.LlmArgs."
|
||||
)
|
||||
|
||||
# TODO: detect whether SSM layer is present in the model and raise an error or disable block
|
||||
# reuse with a warning --> see https://github.com/NVIDIA/TensorRT-LLM/issues/7142. For now, we
|
||||
# just emit a general warning.
|
||||
if enable_block_reuse:
|
||||
ad_logger.warning(
|
||||
f"{enable_block_reuse=} is enabled. Note that this is not supported for SSM layers and"
|
||||
" may lead to incorrect results if the model contains SSM layers."
|
||||
)
|
||||
|
||||
# resource managers
|
||||
kv_cache_manager = _CacheManagerWithFakePool(
|
||||
ad_config.kv_cache_config,
|
||||
num_blocks=engine.cache_seq_interface.info.num_pages,
|
||||
tokens_per_block=ad_config.attn_page_size,
|
||||
max_seq_len=ad_config.max_seq_len,
|
||||
max_batch_size=ad_config.max_batch_size,
|
||||
)
|
||||
# KVCacheManager is now created and managed by CachedSequenceInterface during the
|
||||
# initialize_cache/resize_kv_cache transform pipeline. Get it from the interface.
|
||||
kv_cache_manager = engine.cache_seq_interface.kv_cache_manager
|
||||
kv_cache_config_tuned = engine.cache_seq_interface.kv_cache_config_tuned
|
||||
seq_slot_manager = SeqSlotManager(max_num_sequences=max_num_sequences)
|
||||
|
||||
draft_kv_cache_manager = create_draft_kv_cache_manager_maybe(
|
||||
draft_model_engine, ad_config, dist_mapping
|
||||
draft_model_engine, ad_config, kv_cache_config_tuned, dist_mapping
|
||||
)
|
||||
|
||||
resource_manager = ResourceManager(
|
||||
@ -1160,7 +1120,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
|
||||
|
||||
# Chunked prefill
|
||||
if ad_config.enable_chunked_prefill:
|
||||
chunk_unit_size = ad_config.attn_page_size
|
||||
chunk_unit_size = kv_cache_config_tuned.tokens_per_block
|
||||
chunking_policy = ContextChunkingPolicy.FIRST_COME_FIRST_SERVED
|
||||
ctx_chunk_config: Tuple[StrEnum, int] = (chunking_policy, chunk_unit_size)
|
||||
else:
|
||||
|
||||
@ -71,14 +71,15 @@ class DemoEngine(ADEngine):
        currently available token slots in the assigned pages and assign a new, previously
        unassigned page if needed.
        """
        si = self.cache_seq_interface.info
        page_assignments = si.page_assignments
        page_assignments = self.cache_seq_interface.info.page_assignments
        num_pages = self.cache_seq_interface.kv_cache_manager.blocks_in_primary_pool
        tokens_per_block = self.cache_seq_interface.kv_cache_manager.tokens_per_block

        free_pages = set(range(si.num_pages)) - {i for pages in page_assignments for i in pages}
        free_pages = set(range(num_pages)) - {i for pages in page_assignments for i in pages}
        updated_assignments = []
        for t_l, pages in zip(total_lens, page_assignments):
            extra_tokens = t_l - len(pages) * si.page_size
            num_extra_pages = (extra_tokens // si.page_size) + (extra_tokens > 0)
            extra_tokens = t_l - len(pages) * tokens_per_block
            num_extra_pages = (extra_tokens // tokens_per_block) + (extra_tokens > 0)
            updated_assignments.append(pages + [free_pages.pop() for _ in range(num_extra_pages)])
        return updated_assignments

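# NOTE: Illustrative arithmetic (not part of this diff) for the update above, assuming
# tokens_per_block=32: a sequence of total length 70 that already owns 2 pages (64 slots)
# has extra_tokens = 70 - 64 = 6 and num_extra_pages = (6 // 32) + (6 > 0) = 1, so exactly
# one page is popped from free_pages for it.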
@ -1,95 +1,473 @@
|
||||
from typing import Callable, Dict, List, Optional, Tuple, final
|
||||
import copy
|
||||
import functools
|
||||
import math
|
||||
from typing import Callable, Dict, Optional, Tuple, Union, final
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch._prims_common import DeviceLikeType
|
||||
|
||||
from ..custom_ops.attention_interface import GetCacheCallable, SequenceInfo
|
||||
import tensorrt_llm.bindings
|
||||
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager
|
||||
from ...pyexecutor.resource_manager import KVCacheManager
|
||||
from ..custom_ops.attention_interface import (
|
||||
PagedResourceHandler,
|
||||
ResourceHandler,
|
||||
ResourceHandlerDict,
|
||||
SequenceInfo,
|
||||
StateResourceHandler,
|
||||
)
|
||||
from ..distributed.common import all_gather_object, get_world_size
|
||||
from ..distributed.common import is_initialized as is_distributed_initialized
|
||||
from ..utils.cuda_mem_tracker import bytes_to, get_mem_info
|
||||
from ..utils.logger import ad_logger
|
||||
|
||||
CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType
|
||||
DataType = tensorrt_llm.bindings.DataType
|
||||
|
||||
|
||||
def with_pre_callback(method, callback):
|
||||
"""Wrap method to call callback before the original method."""
|
||||
|
||||
@functools.wraps(method)
|
||||
def wrapper(*args, **kwargs):
|
||||
callback()
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@final
|
||||
class CachedSequenceInterface:
|
||||
"""An interface responsible for maintaining information about sequences and their caches."""
|
||||
"""An interface responsible for maintaining information about sequences and their caches.
|
||||
|
||||
This class is the single source of truth for sequence and cache configuration. It creates
|
||||
SequenceInfo internally, ensuring that tokens_per_block and other fields from KvCacheConfig
|
||||
are always consistent.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, sequence_info: SequenceInfo, device: Optional[DeviceLikeType] = None
|
||||
self,
|
||||
max_seq_len: int,
|
||||
max_batch_size: int,
|
||||
device: Optional[DeviceLikeType] = None,
|
||||
kv_cache_config: Optional[KvCacheConfig] = None,
|
||||
max_num_tokens: Optional[int] = None,
|
||||
vocab_size_padded: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Initialize the CachedSequenceInterface.
|
||||
|
||||
Args:
|
||||
max_seq_len: Maximum sequence length including input and generated tokens.
|
||||
max_batch_size: Maximum number of sequences (requests) that can be processed.
|
||||
device: Target device for tensors. Defaults to "cuda".
|
||||
kv_cache_config: KV cache configuration. If None, uses default KvCacheConfig.
|
||||
max_num_tokens: Maximum total tokens across all sequences. If None, computed from
|
||||
max_seq_len and max_batch_size.
|
||||
vocab_size_padded: Padded vocabulary size of the model.
|
||||
"""
|
||||
# TODO (lucaslie): this is somewhat circular/confusing. Here `device` denotes the desired
|
||||
# device and not the actual device unlike, e.g., in SequenceInfo. We rely on the attribute
|
||||
# here to read the desired device across the inference optimizer pipeline. We should ideally
|
||||
# think about a better way to handle this,
|
||||
# see https://github.com/NVIDIA/TensorRT-LLM/issues/8371
|
||||
self.device = device or "cuda"
|
||||
self.info = sequence_info
|
||||
self._cache_initializers: Dict[str, GetCacheCallable] = {}
|
||||
|
||||
# Initialize kv_cache_config first since SequenceInfo needs tokens_per_block from it
|
||||
self._kv_cache_config_original: KvCacheConfig = kv_cache_config or KvCacheConfig()
|
||||
self._kv_cache_config_tuned: Optional[KvCacheConfig] = None
|
||||
|
||||
# Create SequenceInfo internally, using tokens_per_block from kv_cache_config
|
||||
self.info = SequenceInfo(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
tokens_per_block=self._kv_cache_config_original.tokens_per_block,
|
||||
max_num_tokens=max_num_tokens,
|
||||
vocab_size_padded=vocab_size_padded,
|
||||
)
|
||||
|
||||
self._resource_lookup: ResourceHandlerDict = {}
|
||||
self._caches: Dict[str, torch.Tensor] = {}
|
||||
# KVCacheManager (or MambaHybridCacheManager) for managed resources
|
||||
self._kv_cache_manager: Optional[Union[KVCacheManager, MambaHybridCacheManager]] = None
|
||||
# Ordered dicts tracking resource handlers by type
|
||||
self._paged_cache_order: ResourceHandlerDict = {} # Paged resources (kv caches)
|
||||
self._state_resource_order: ResourceHandlerDict = {} # State resources (ssm states)
|
||||
|
||||
@property
|
||||
def args(self) -> Tuple[torch.Tensor, ...]:
|
||||
"""Return all the graph arguments owned by this interface."""
|
||||
return (*self.info.args, *self._caches.values())
|
||||
return tuple(self.named_args.values())
|
||||
|
||||
@property
|
||||
def named_args(self) -> Dict[str, torch.Tensor]:
|
||||
"""Return all the named arguments owned by this interface."""
|
||||
return {**self.info.named_args, **self._caches}
|
||||
|
||||
@property
|
||||
def all_future_arg_names(self) -> List[str]:
|
||||
"""Return all the argument names owned by this interface including uninitialized caches."""
|
||||
return list(self.info.named_args.keys()) + list(self._cache_initializers.keys())
|
||||
|
||||
def to(self, *args, **kwargs) -> None:
|
||||
self.info.to(*args, **kwargs)
|
||||
if self._caches:
|
||||
for cache in self._caches.values():
|
||||
# Only move locally-allocated caches (paged/state caches are managed by cache managers)
|
||||
for name, cache in self._caches.items():
|
||||
if name not in self._paged_cache_order and name not in self._state_resource_order:
|
||||
cache.to(*args, **kwargs)
|
||||
|
||||
def add_cache(self, name: str, get_cache: GetCacheCallable) -> None:
|
||||
"""Add a cache initializer to the cache interface."""
|
||||
self._cache_initializers[name] = get_cache
|
||||
def update_kv_cache_config(self, **kwargs) -> None:
|
||||
"""Update the KVCacheConfig with the given kwargs."""
|
||||
for k, v in kwargs.items():
|
||||
if k in type(self._kv_cache_config_original).model_fields:
|
||||
setattr(self._kv_cache_config_original, k, v)
|
||||
else:
|
||||
raise ValueError(f"Invalid KVCacheConfig field: {k}")
|
||||
|
||||
def initialize_caches(self) -> int:
|
||||
"""Initialize caches using the cache initializers."""
|
||||
assert not self._caches, "Caches already initialized."
|
||||
self.info.to(self.device)
|
||||
self._caches = {
|
||||
name: get_cache(self.info) for name, get_cache in self._cache_initializers.items()
|
||||
def add_resource(self, name: str, resource_handler: ResourceHandler) -> None:
|
||||
"""Add a resource handler to the cache interface."""
|
||||
self._resource_lookup[name] = resource_handler
|
||||
|
||||
    def _create_kv_cache_manager(self, max_tokens: Optional[int] = None) -> int:
        """Create KVCacheManager or MambaHybridCacheManager with multi-layer byte-level params.

        This uses a multi-layer approach with byte-level abstraction:
        - Paged resources: Each resource gets its own layer in KVCacheManager with
          num_kv_heads=bytes_per_token for that resource, head_dim=1.
        - State resources: Each resource gets its own layer in MambaCacheManager with
          head_dim=bytes_per_slot for that resource.

        Each layer's cache is contiguous, avoiding byte-offset slicing within layers.

        When state resources exist, MambaHybridCacheManager is used to manage both.

        Important NOTE on contiguity of managed resources:
        - We only guarantee contiguity for an individual page or an individual state slot.
        - Outside of these individual pages/slots, resources are NOT guaranteed to be contiguous.

        Args:
            max_tokens: Maximum number of tokens to allocate. If provided, it will use the min value
                between this value and max_tokens in kv_cache_config.

        Returns:
            The final number of tokens that can be cached in the KVCacheManager.
            NOTE: this number may differ from the provided ``max_tokens`` arg for two reasons:
                1. the final number of tokens is synced (min) across ranks
                2. rounding for getting a multiple of tokens_per_block
        """
        # Build per-layer num_kv_heads list for paged resources
        # Each paged resource becomes one "layer" with num_kv_heads = bytes_per_token
        num_kv_heads_per_layer = [
            math.prod(h.token_shape) * h.dtype.itemsize for h in self._paged_cache_order.values()
        ]

        # Calculate total bytes per slot for state resources (modeled as single layer)
        cumulative_bytes_per_state = [0]
        for name, handler in self._state_resource_order.items():
            byte_size = math.prod(handler.state_shape) * handler.dtype.itemsize
            cumulative_bytes_per_state.append(cumulative_bytes_per_state[-1] + byte_size)

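        # NOTE: Illustrative arithmetic (not part of this diff): with this byte-level
        # abstraction, a paged k-cache whose per-token shape is (num_kv_heads=8, head_dim=128)
        # in bfloat16 contributes 8 * 128 * 2 = 2048 bytes per token, so its layer is declared
        # to the KVCacheManager as num_kv_heads=2048 with head_dim=1.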
# Make a deep copy of the kv_cache_config to avoid modifying the original object
|
||||
kv_cache_config = copy.deepcopy(self._kv_cache_config_original)
|
||||
|
||||
# Disable copy_on_partial_reuse
|
||||
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/10966
|
||||
if kv_cache_config.copy_on_partial_reuse:
|
||||
kv_cache_config.copy_on_partial_reuse = False
|
||||
ad_logger.info("Disabling copy_on_partial_reuse for AutoDeploy backend.")
|
||||
|
||||
# Update kv_cache_config based on max_tokens if provided
|
||||
if max_tokens is not None:
|
||||
# sync max_tokens across ranks
|
||||
if is_distributed_initialized():
|
||||
max_tokens_gathered = [None] * get_world_size()
|
||||
all_gather_object(max_tokens_gathered, max_tokens)
|
||||
max_tokens = min(max_tokens_gathered)
|
||||
kv_cache_config.free_gpu_memory_fraction = None
|
||||
kv_cache_config.max_tokens = min(kv_cache_config.max_tokens or max_tokens, max_tokens)
|
||||
|
||||
# Check if we should disable block reuse
|
||||
if kv_cache_config.enable_block_reuse and not self.is_paged():
|
||||
kv_cache_config.enable_block_reuse = False
|
||||
ad_logger.info(f"Setting {kv_cache_config.enable_block_reuse=} for non-paged models.")
|
||||
|
||||
# Make sure to set free_gpu_memory_fraction to None if set to 0.0
|
||||
# NOTE: KVCacheConfig validator enforces that free_gpu_memory_fraction must be between 0.0
|
||||
# and 1.0 but we allow 0.0 to be set to disable resizing (corresponding to None in the
|
||||
# manager).
|
||||
if kv_cache_config.free_gpu_memory_fraction == 0.0:
|
||||
kv_cache_config.free_gpu_memory_fraction = None
|
||||
|
||||
# Common KV cache parameters
|
||||
kv_cache_kwargs = {
|
||||
"kv_cache_config": kv_cache_config,
|
||||
"kv_cache_type": CacheTypeCpp.SELFKONLY, # kv_factor=1, treat K, V separately
|
||||
"num_layers": len(self._paged_cache_order), # correct num layers
|
||||
"num_kv_heads": num_kv_heads_per_layer, # per-layer bytes_per_token
|
||||
"head_dim": 1, # all bytes in num_kv_heads
|
||||
"tokens_per_block": kv_cache_config.tokens_per_block,
|
||||
"max_seq_len": self.info.max_seq_len,
|
||||
"max_batch_size": self.info.max_batch_size,
|
||||
"mapping": Mapping(),
|
||||
# NOTE (lucaslie): this is the only 1-byte dtype currently supported by the
|
||||
# KVCacheManager. Ideally, we would use the typical uint8 dtype for byte-level
|
||||
# abstraction, but this is not supported.
|
||||
"dtype": DataType.FP8, # 1-byte dtype for byte-level abstraction
|
||||
"layer_mask": None,
|
||||
# NOTE (lucaslie): we can always run with False here since when we are estimating, we
|
||||
# are explicitly setting the max_tokens in which case it's okay to use False here since
|
||||
# we don't rely on free_gpu_memory_fraction inside the KVCacheManager. This is similar
|
||||
# to _torch.pyexecutor._util.KVCacheCreator, which explicitly estimates the max_tokens
|
||||
# outside of the KVCacheManager.
|
||||
"is_estimating_kv_cache": False,
|
||||
}
|
||||
|
||||
# update args if we are just doing a dummy cache manager
|
||||
if not len(self._paged_cache_order):
|
||||
kv_cache_kwargs.update(
|
||||
{
|
||||
"num_layers": 1,
|
||||
"num_kv_heads": 1,
|
||||
"head_dim": 1,
|
||||
}
|
||||
)
|
||||
|
||||
if self._state_resource_order:
|
||||
# NOTE: +1 for cuda graph padding
|
||||
kv_cache_kwargs["max_batch_size"] = self.info.max_num_state_slots
|
||||
|
||||
self._kv_cache_manager = MambaHybridCacheManager(
|
||||
# Mamba params for single-layer byte buffer
|
||||
mamba_d_state=1,
|
||||
mamba_d_conv=1, # conv_states will have shape [..., 0] (empty)
|
||||
mamba_num_heads=1,
|
||||
mamba_n_groups=1,
|
||||
mamba_head_dim=cumulative_bytes_per_state[-1], # Total bytes per slot
|
||||
mamba_num_layers=1, # Single layer
|
||||
mamba_layer_mask=None, # Single enabled layer
|
||||
mamba_cache_dtype=torch.uint8, # Byte-level
|
||||
mamba_ssm_cache_dtype=torch.uint8, # Byte-level
|
||||
# KV cache params
|
||||
**kv_cache_kwargs,
|
||||
)
|
||||
else:
|
||||
# No state resources - use pure KVCacheManager
|
||||
self._kv_cache_manager = KVCacheManager(**kv_cache_kwargs)
|
||||
|
||||
# store the tuned kv_cache_config
|
||||
self._kv_cache_config_tuned = kv_cache_config
|
||||
|
||||
# Ensure cache_loc capacity is sufficient for the new KVCacheManager
|
||||
blocks_in_primary_pool = self._kv_cache_manager.blocks_in_primary_pool
|
||||
tokens_per_block = self._kv_cache_manager.tokens_per_block
|
||||
self.info.estimate_cache_loc_capacity(blocks_in_primary_pool)
|
||||
|
||||
# Create paged resource views from per-layer buffers
|
||||
for layer_idx, (name, handler) in enumerate(self._paged_cache_order.items()):
|
||||
view = self._kv_cache_manager.get_buffers(layer_idx, kv_layout="NHD")
|
||||
view = view.view(blocks_in_primary_pool, tokens_per_block, -1).view(handler.dtype)
|
||||
view = view.view(blocks_in_primary_pool, tokens_per_block, *handler.token_shape)
|
||||
|
||||
# Sanity check on contiguity of individual pages
|
||||
view_one_page = view[0]
|
||||
assert view_one_page.is_contiguous(), f"Per-page cache for {name} is not contiguous"
|
||||
|
||||
self._caches[name] = view
|
||||
|
||||
for layer_idx, (name, handler) in enumerate(self._state_resource_order.items()):
|
||||
num_states = len(self._kv_cache_manager.state_indices)
|
||||
# Get the single-layer ssm_states buffer
|
||||
# ssm_states shape: [1, num_states, 1, total_bytes_per_slot, 1]
|
||||
ssm_buffer = self._kv_cache_manager.get_ssm_states(0)
|
||||
# Flatten to [max_batch, total_bytes_per_slot_for_all_layers]
|
||||
ssm_buffer = ssm_buffer.view(num_states, -1)
|
||||
|
||||
offset_start = cumulative_bytes_per_state[layer_idx]
|
||||
offset_end = cumulative_bytes_per_state[layer_idx + 1]
|
||||
|
||||
# Slice at byte offset, reinterpret dtype, reshape
|
||||
view = ssm_buffer[:, offset_start:offset_end]
|
||||
view = view.view(handler.dtype)
|
||||
view = view.view(num_states, *handler.state_shape)
|
||||
|
||||
# Sanity check on contiguity of individual state slots
|
||||
assert view[0].is_contiguous(), f"Per-slot state for {name} cache is not contiguous"
|
||||
|
||||
self._caches[name] = view
|
||||
|
||||
# Patch shutdown to clear cache views before pool release
|
||||
self._kv_cache_manager.shutdown = with_pre_callback(
|
||||
self._kv_cache_manager.shutdown,
|
||||
self._clear_cache_views,
|
||||
)
|
||||
|
||||
max_resource_count = self._kv_cache_manager.get_max_resource_count()
|
||||
max_tokens_final = max_resource_count * self._kv_cache_manager.tokens_per_block
|
||||
|
||||
return max_tokens_final
|
||||
|
||||
def initialize_resources(self) -> int:
|
||||
"""Initialize resources - paged/state caches via cache managers, others separately.
|
||||
|
||||
Paged resources are managed by KVCacheManager (or the KV portion of MambaHybridCacheManager).
|
||||
State resources are managed by the Mamba portion of MambaHybridCacheManager.
|
||||
Other resources are allocated locally as a fallback.
|
||||
|
||||
Returns:
|
||||
The number of caches initialized.
|
||||
"""
|
||||
assert not self._caches and not self._paged_cache_order, "Caches already initialized."
|
||||
self.info.to(self.device)
|
||||
|
||||
# Separate resources by type
|
||||
for name, handler in self._resource_lookup.items():
|
||||
if isinstance(handler, PagedResourceHandler):
|
||||
self._paged_cache_order[name] = handler
|
||||
self._caches[name] = None # Will be set by _create_kv_cache_manager
|
||||
elif isinstance(handler, StateResourceHandler):
|
||||
self._state_resource_order[name] = handler
|
||||
self._caches[name] = None # Will be set by _create_kv_cache_manager
|
||||
else:
|
||||
# Unknown handler type - allocate locally (fallback)
|
||||
self._caches[name] = handler.allocate(self.info)
|
||||
|
||||
# Create unified cache manager (handles both paged and state resources)
|
||||
if self.needs_resize() or self._requires_token_estimate():
|
||||
max_tokens_estimate = self.info.estimate_cache_tokens_per_forward()
|
||||
else:
|
||||
# if we don't need a resize, we will just use the original settings in kv_cache_config
|
||||
# instead of passing in an overwrite here.
|
||||
max_tokens_estimate = None
|
||||
self._create_kv_cache_manager(max_tokens=max_tokens_estimate)
|
||||
|
||||
return len(self._caches)
|
||||
|
||||
def current_cache_size_bytes(self) -> int:
|
||||
"""Calculate and return the total size of all caches in bytes."""
|
||||
total_size = 0
|
||||
for name, cache in self._caches.items():
|
||||
# this hack is needed since _caches also contains global buffers such as freqs_cis.
|
||||
if "cache" in name:
|
||||
total_size += cache.element_size() * cache.numel()
|
||||
return total_size
|
||||
    def is_paged(self) -> bool:
        """Return True if all resources are paged and part of the KVCacheManager."""
        return set(self._paged_cache_order.keys()) == set(self._resource_lookup.keys())

    def current_kv_cache_size_bytes(self) -> int:
        """Return size in bytes of KV caches only (k_cache_*, v_cache_*).
    def _requires_token_estimate(self) -> bool:
        """Check if our kv_cache_config requires a token estimate."""
        return (
            self._kv_cache_config_original.free_gpu_memory_fraction in [None, 0.0]
            and self._kv_cache_config_original.max_tokens is None
        )

        Excludes SSM/conv/etc. which do not scale with num_pages.
    def needs_resize(self) -> bool:
        """Check if we need a resize or not."""
        has_paged = bool(self._paged_cache_order)
        return has_paged and self._kv_cache_config_original.free_gpu_memory_fraction not in [
            None,
            0.0,
        ]

    def resize_kv_cache_manager(self, mem_exclude: int = 0) -> None:
        """Shutdown existing KVCacheManager and create new one with optimal capacity.

        Args:
            mem_exclude: Extra memory to exclude from the calculation of optimal capacity.
                This is in bytes and typically the memory reserved for the forward pass.

        This implements the two-phase approach: after running a forward pass during estimation
        to allocate intermediate memory, call this method to recreate the KVCacheManager.
        The new manager will compute optimal capacity based on current free GPU memory
        via calculate_max_num_blocks.
        """
        total_size = 0
        for name, cache in self._caches.items():
            if name.startswith("k_cache_") or name.startswith("v_cache_"):
                total_size += cache.element_size() * cache.numel()
        return total_size
        if not self.needs_resize():
            return

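# NOTE: Illustrative sketch (not part of this diff) of the two-phase flow described in the
# docstring above. The real call sites live in the AutoDeploy transforms/engine; names and
# argument values below are assumptions for illustration only.
#
#     csi = CachedSequenceInterface(
#         max_seq_len=1024,
#         max_batch_size=8,
#         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.8),
#     )
#     # ... add_resource(...) is called per layer while building the cached model ...
#     csi.initialize_resources()         # phase 1: manager sized from a token estimate
#     # ... run one forward pass so intermediate/workspace memory gets allocated ...
#     if csi.needs_resize():
#         csi.resize_kv_cache_manager()  # phase 2: re-create manager from free GPU memory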
def resize_cache(self, new_num_pages: int):
|
||||
"""Resize the cache to the new number of pages."""
|
||||
# TODO: We should do some sanity check on the new number of pages.
|
||||
self.info.num_pages = new_num_pages
|
||||
for name, cache in self._caches.items():
|
||||
# We assume cache is a tensor of shape (max_batch_size, page_size, n_heads, head_dim)
|
||||
# TODO: cache resize should ideally be handled via a callback to the AttentionDescriptor
|
||||
# to avoid hard-coding any assumptions about the cache shape or its "pagedness"
|
||||
if "k_cache" in name or "v_cache" in name:
|
||||
current_shape = cache.shape
|
||||
new_shape = (new_num_pages, *current_shape[1:])
|
||||
cache.resize_(new_shape)
|
||||
# get per-token cache size for resizable resources
|
||||
paged_cache_bytes_per_token = self._kv_cache_manager.get_cache_bytes_per_token()
|
||||
|
||||
# get total cache size of state resources that cannot be resized
|
||||
# NOTE: this does NOT include resources handled OUTSIDE of the KVCacheManager or
|
||||
# MambaHybridCacheManager. Those will persist and will be accounted for via free_mem even
# after the initial kv_cache_manager is shut down.
|
||||
state_cache_bytes_total = sum(
|
||||
cache.numel() * cache.element_size()
|
||||
for name, cache in self._caches.items()
|
||||
if name in self._state_resource_order
|
||||
)
|
||||
|
||||
# get unmanaged cache size
|
||||
unmanaged_cache_bytes_total = sum(
|
||||
cache.numel() * cache.element_size()
|
||||
for name, cache in self._caches.items()
|
||||
if name not in self._paged_cache_order and name not in self._state_resource_order
|
||||
)
|
||||
|
||||
# Shutdown existing KVCacheManager to free memory
|
||||
self._kv_cache_manager.shutdown()
|
||||
|
||||
# Get current free GPU memory (roughly includes model weights + non-managed resources)
|
||||
_, free_mem, *_ = get_mem_info(empty_cache=True)
|
||||
|
||||
# Compute available memory for the KVCacheManager
|
||||
# NOTE: free_mem was obtained AFTER shutdown of initial KVCacheManager - hence it accounts
|
||||
# for unmanaged resources but it does NOT account for state resources since those were
|
||||
# freed as part of the shutdown.
|
||||
free_gpu_memory_fraction = self._kv_cache_config_original.free_gpu_memory_fraction
|
||||
mem_for_paged_optimal = (
|
||||
free_mem - state_cache_bytes_total - mem_exclude
|
||||
) * free_gpu_memory_fraction
|
||||
# Check how many tokens we can fit into the paged cache
|
||||
max_tokens_optimal = int(mem_for_paged_optimal // paged_cache_bytes_per_token)
|
||||
|
||||
# Create new KVCacheManager with final capacity
|
||||
max_tokens_final = self._create_kv_cache_manager(max_tokens=max_tokens_optimal)
|
||||
|
||||
# Log resulting memory information
|
||||
mem_info = [
|
||||
f"free_mem={bytes_to(free_mem, unit='GB'):.2f}GB",
|
||||
f"free_gpu_memory_fraction={free_gpu_memory_fraction}",
|
||||
f"mem_exclude={bytes_to(mem_exclude, unit='GB'):.2f}GB",
|
||||
f"mem_exclude_for_state={bytes_to(state_cache_bytes_total, unit='GB'):.2f}GB",
|
||||
f"mem_for_paged_optimal={bytes_to(mem_for_paged_optimal, unit='GB'):.2f}GB",
|
||||
]
|
||||
total_cache_bytes = (
|
||||
mem_for_paged_optimal + state_cache_bytes_total + unmanaged_cache_bytes_total
|
||||
)
|
||||
mem_cache_info = [
|
||||
f"Max Tokens={max_tokens_final}",
|
||||
f"Paged={bytes_to(mem_for_paged_optimal, unit='GB'):.2f}GB",
|
||||
f"State={bytes_to(state_cache_bytes_total, unit='GB'):.2f}GB",
|
||||
f"Unmanaged={bytes_to(unmanaged_cache_bytes_total, unit='GB'):.2f}GB",
|
||||
f"Total={bytes_to(total_cache_bytes, unit='GB'):.2f}GB",
|
||||
]
|
||||
ad_logger.info(f"Mem info for resize: {' | '.join(mem_info)}")
|
||||
ad_logger.info(f"Final Cache Mem: {' | '.join(mem_cache_info)}")
|
||||
|
||||
@property
|
||||
def kv_cache_manager(self) -> Optional[KVCacheManager]:
|
||||
"""Return the KVCacheManager managing paged resources, or None if not initialized."""
|
||||
assert self._kv_cache_manager is not None, "KVCacheManager not initialized."
|
||||
return self._kv_cache_manager
|
||||
|
||||
@property
|
||||
def kv_cache_config_tuned(self) -> KvCacheConfig:
|
||||
"""Return the KVCacheConfig tuned for the KVCacheManager."""
|
||||
assert None not in [self._kv_cache_manager, self._kv_cache_config_tuned], (
|
||||
"KVCacheManager not initialized."
|
||||
)
|
||||
return self._kv_cache_config_tuned
|
||||
|
||||
@property
|
||||
def kv_cache_config(self) -> KvCacheConfig:
|
||||
"""Return the original KVCacheConfig as passed in."""
|
||||
return self._kv_cache_config_original
|
||||
|
||||
def _clear_cache_views(self) -> None:
|
||||
"""Set paged and state cache views to None before pool release."""
|
||||
self._kv_cache_config_tuned = None
|
||||
for name in self._paged_cache_order:
|
||||
self._caches[name] = None
|
||||
for name in self._state_resource_order:
|
||||
self._caches[name] = None
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown and release all resources."""
|
||||
if self._kv_cache_manager is not None:
|
||||
self._kv_cache_manager.shutdown()
|
||||
self._kv_cache_config_tuned = None
|
||||
self._caches.clear()
|
||||
|
||||
|
||||
GetInferenceModel = Callable[[CachedSequenceInterface], nn.Module]
|
||||
|
||||
@ -5,7 +5,8 @@ This module defines the base classes and interfaces for all transforms.
|
||||
|
||||
import time
|
||||
from abc import ABC
|
||||
from contextlib import nullcontext
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import total_ordering, wraps
|
||||
from typing import Any, Callable, Dict, Mapping, Tuple, Type, Union, final
|
||||
@ -24,8 +25,60 @@ from ..utils._graph import (
|
||||
placeholders_on_meta,
|
||||
run_shape_prop,
|
||||
)
|
||||
from ..utils.cuda_mem_tracker import get_mem_info
|
||||
from ..utils.logger import ad_logger
|
||||
|
||||
# ANSI color codes for log formatting (set to False to disable colors)
|
||||
# NOTE: colors disabled by default to make logging in CI/CD pipelines easier to read
|
||||
_ENABLE_LOG_COLORS = False
|
||||
|
||||
|
||||
class _Colors:
|
||||
RESET = "\033[0m" if _ENABLE_LOG_COLORS else ""
|
||||
BOLD = "\033[1m" if _ENABLE_LOG_COLORS else ""
|
||||
DIM = "\033[2m" if _ENABLE_LOG_COLORS else ""
|
||||
CYAN = "\033[36m" if _ENABLE_LOG_COLORS else ""
|
||||
MAGENTA = "\033[35m" if _ENABLE_LOG_COLORS else ""
|
||||
GREEN = "\033[32m" if _ENABLE_LOG_COLORS else ""
|
||||
YELLOW = "\033[33m" if _ENABLE_LOG_COLORS else ""
|
||||
ORANGE = "\033[38;5;208m" if _ENABLE_LOG_COLORS else ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemStats:
|
||||
"""Memory statistics snapshot for tracking CUDA memory usage."""
|
||||
|
||||
tot: float
|
||||
free: float
|
||||
resv: float
|
||||
alloc: float
|
||||
frag: float
|
||||
|
||||
def diff(self, other: "MemStats") -> "MemStats":
|
||||
"""Calculate the difference (self - other)."""
|
||||
return MemStats(
|
||||
tot=self.tot - other.tot,
|
||||
free=self.free - other.free,
|
||||
resv=self.resv - other.resv,
|
||||
alloc=self.alloc - other.alloc,
|
||||
frag=self.frag - other.frag,
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, float]:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"tot": self.tot,
|
||||
"free": self.free,
|
||||
"resv": self.resv,
|
||||
"alloc": self.alloc,
|
||||
"frag": self.frag,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: Dict[str, float]) -> "MemStats":
|
||||
"""Create from dictionary."""
|
||||
return cls(tot=d["tot"], free=d["free"], resv=d["resv"], alloc=d["alloc"], frag=d["frag"])
|
||||
|
||||
class TransformError(Exception):
|
||||
"""An exception raised when a transform fails."""
|
||||
@ -114,6 +167,11 @@ class TransformConfig(BaseModel):
|
||||
description="Whether this transform requires shape propagation before it is applied.",
|
||||
)
|
||||
|
||||
expect_mem_change: bool = Field(
|
||||
default=False,
|
||||
description="Whether this transform is expected to cause changes in CUDA memory stats.",
|
||||
)
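
# Illustrative sketch, assuming per-transform YAML options map onto TransformConfig fields
# (as with `backend` for insert_cached_attention): a transform expected to move CUDA memory
# could opt in via
#   transforms:
#     resize_kv_cache:
#       expect_mem_change: true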
|
||||
|
||||
AutodeployMeta = Dict[str, Any]
|
||||
_UntypedInferenceOptimizerConfig = Dict[str, Any]
|
||||
@ -223,6 +281,7 @@ class BaseTransform(ABC):
|
||||
config: TransformConfig # overwrite type hint if other config cls is used in subclass!
|
||||
_autodeploy_meta_key: str = "_autodeploy"
|
||||
_history_key: str = "transform_history"
|
||||
_mem_history_key: str = "mem_history"
|
||||
_transform_key: str # Set by TransformRegistry.register() decorator
|
||||
|
||||
@classmethod
|
||||
@ -324,6 +383,9 @@ class BaseTransform(ABC):
|
||||
# show debug info for debug config
|
||||
ad_logger.debug(f"{t_name} config: {self.config}")
|
||||
|
||||
# capture memory stats at the start
|
||||
mem_pre = self._get_mem_stats(empty_cache=True)
|
||||
|
||||
# store some timing information
|
||||
elapsed_time_total = -time.time()
|
||||
elapsed_time_pre_cleanup = 0.0
|
||||
@ -340,23 +402,28 @@ class BaseTransform(ABC):
|
||||
self.config.requires_shape_prop,
|
||||
info.is_clean,
|
||||
info.has_valid_shapes,
|
||||
phase="pre",
|
||||
)
|
||||
elapsed_time_pre_cleanup += time.time()
|
||||
|
||||
# run the transform in an error-handling wrapper if desired
|
||||
elapsed_time_apply = -time.time()
|
||||
if self.config.skip_on_error:
|
||||
try:
|
||||
with self._apply_logging_context():
|
||||
self._log_info("applying transform...")
|
||||
if self.config.skip_on_error:
|
||||
try:
|
||||
mod, info_apply = self._apply_per_gm_or_whole_model(
|
||||
mod, cm, factory, shared_config
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Transform {t_name} failed"
|
||||
ad_logger.warning(f"{error_msg}: {e}")
|
||||
info_apply = TransformInfo(skipped=True, num_matches=0)
|
||||
else:
|
||||
# handle this here normally to improve debugging and error message
|
||||
mod, info_apply = self._apply_per_gm_or_whole_model(
|
||||
mod, cm, factory, shared_config
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Transform {t_name} failed"
|
||||
ad_logger.warning(f"{error_msg}: {e}")
|
||||
info_apply = TransformInfo(skipped=True, num_matches=0)
|
||||
else:
|
||||
# handle this here normally to improve debugging and error message
|
||||
mod, info_apply = self._apply_per_gm_or_whole_model(mod, cm, factory, shared_config)
|
||||
elapsed_time_apply += time.time()
|
||||
|
||||
# we cannot say it's clean if the previous wasn't clean even if this one is
|
||||
@ -371,31 +438,40 @@ class BaseTransform(ABC):
|
||||
self.config.run_shape_prop,
|
||||
info.is_clean,
|
||||
info.has_valid_shapes,
|
||||
phase="post",
|
||||
)
|
||||
elapsed_time_post_cleanup += time.time()
|
||||
|
||||
elapsed_time_total += time.time()
|
||||
|
||||
# capture memory stats at the end and log summary (only log if enabled)
|
||||
mem_post = self._get_mem_stats(empty_cache=True)
|
||||
if self.config.enabled:
|
||||
self._log_mem_summary(mem_pre, mem_post, self.config.expect_mem_change)
|
||||
|
||||
# log the result of the transform
|
||||
log_msgs = [
|
||||
f"enabled={self.config.enabled}",
|
||||
"skipped=True" if info.skipped else f"num_matches={info.num_matches}",
|
||||
f"is_clean={info.is_clean}",
|
||||
f"has_valid_shapes={info.has_valid_shapes}",
|
||||
]
|
||||
self._log_info(", ".join(log_msgs))
|
||||
log_msgs_timing = [
|
||||
f"elapsed time: total={elapsed_time_total:.3f}s",
|
||||
f"pre_cleanup={elapsed_time_pre_cleanup:.3f}s",
|
||||
f"apply={elapsed_time_apply:.3f}s",
|
||||
f"post_cleanup={elapsed_time_post_cleanup:.3f}s",
|
||||
]
|
||||
self._log_info(", ".join(log_msgs_timing))
|
||||
self._log_transform_summary(
|
||||
enabled=self.config.enabled,
|
||||
skipped=info.skipped,
|
||||
num_matches=info.num_matches,
|
||||
elapsed_total=elapsed_time_total,
|
||||
elapsed_pre=elapsed_time_pre_cleanup,
|
||||
elapsed_apply=elapsed_time_apply,
|
||||
elapsed_post=elapsed_time_post_cleanup,
|
||||
)
|
||||
ad_logger.debug(f"Model after {t_name}: {mod}")
|
||||
|
||||
# update + store new meta data
|
||||
# update + store new meta data (transform history and memory history)
|
||||
history[t_name] = info
|
||||
autodeploy_meta[self._history_key] = history
|
||||
|
||||
# store memory history
|
||||
mem_history: Dict[str, Dict[str, Dict[str, float]]] = autodeploy_meta.get(
|
||||
self._mem_history_key, {}
|
||||
)
|
||||
mem_history[t_name] = {"pre": mem_pre.to_dict(), "post": mem_post.to_dict()}
|
||||
autodeploy_meta[self._mem_history_key] = mem_history
|
||||
|
||||
self._set_autodeploy_meta(mod, autodeploy_meta)
|
||||
|
||||
# return the graph module
|
||||
@ -423,11 +499,161 @@ class BaseTransform(ABC):
|
||||
info = info & info_apply if info is not None else info_apply
|
||||
return mod, info
|
||||
|
||||
@final
|
||||
def _log_warning(self, *args: Any):
|
||||
"""Log a warning message with the transform key."""
|
||||
ad_logger.warning(*args)
|
||||
|
||||
@final
|
||||
def _log_info(self, *args: Any):
|
||||
"""Log a message with the transform key."""
|
||||
ad_logger.info(*args)
|
||||
|
||||
@final
|
||||
def _log_debug(self, *args: Any):
|
||||
"""Log a message with the transform key."""
|
||||
ad_logger.debug(*args)
|
||||
|
||||
@contextmanager
|
||||
def _apply_logging_context(self):
|
||||
"""Context manager to add [APPLY] prefix to logs during transform execution."""
|
||||
original_log = ad_logger.log
|
||||
apply_label = "[APPLY]"
|
||||
|
||||
def _patched_log(severity, *msg):
|
||||
# Prepend [APPLY] after any existing prefix
|
||||
if msg:
|
||||
first = msg[0] if isinstance(msg[0], str) else str(msg[0])
|
||||
return original_log(severity, f"{apply_label} {first}", *msg[1:])
|
||||
return original_log(severity, apply_label)
|
||||
|
||||
ad_logger.log = _patched_log # type: ignore[assignment]
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
ad_logger.log = original_log # type: ignore[assignment]
|
||||
|
||||
@final
|
||||
def _log_cleanup_status(self, phase: str, action: str, reason: str = "") -> None:
|
||||
"""Log cleanup status with colored formatting.
|
||||
|
||||
Args:
|
||||
phase: "pre" or "post"
|
||||
action: "ran" or "skipped"
|
||||
reason: Description of what ran or why skipped
|
||||
"""
|
||||
label = f"{_Colors.CYAN}[{phase.upper()}-CLEANUP]{_Colors.RESET}"
|
||||
if action == "skipped":
|
||||
self._log_info(f"{label} {_Colors.DIM}skipped ({reason}){_Colors.RESET}")
|
||||
else:
|
||||
self._log_info(f"{label} {reason}")
|
||||
|
||||
@final
|
||||
def _log_transform_summary(
|
||||
self,
|
||||
enabled: bool,
|
||||
skipped: bool,
|
||||
num_matches: int,
|
||||
elapsed_total: float,
|
||||
elapsed_pre: float,
|
||||
elapsed_apply: float,
|
||||
elapsed_post: float,
|
||||
) -> None:
|
||||
"""Log transform summary with colored formatting.
|
||||
|
||||
Args:
|
||||
enabled: Whether the transform was enabled.
|
||||
skipped: Whether the transform was skipped.
|
||||
num_matches: Number of matches found.
|
||||
elapsed_total: Total elapsed time in seconds.
|
||||
elapsed_pre: Pre-cleanup elapsed time in seconds.
|
||||
elapsed_apply: Apply elapsed time in seconds.
|
||||
elapsed_post: Post-cleanup elapsed time in seconds.
|
||||
"""
|
||||
label = f"{_Colors.GREEN}[SUMMARY]{_Colors.RESET}"
|
||||
timing_str = (
|
||||
f"{elapsed_total:.3f}s "
|
||||
f"(pre={elapsed_pre:.3f}s, apply={elapsed_apply:.3f}s, post={elapsed_post:.3f}s)"
|
||||
)
|
||||
|
||||
if not enabled:
|
||||
self._log_info(f"{label} {_Colors.DIM}disabled{_Colors.RESET}")
|
||||
elif skipped:
|
||||
self._log_info(f"{label} {_Colors.DIM}skipped{_Colors.RESET} | time: {timing_str}")
|
||||
else:
|
||||
self._log_info(f"{label} matches={num_matches} | time: {timing_str}")
|
||||
|
||||
@final
|
||||
def _get_mem_stats(self, empty_cache: bool = True) -> MemStats:
|
||||
"""Get current CUDA memory statistics.
|
||||
|
||||
Args:
|
||||
empty_cache: Whether to empty the memory cache before getting the memory stats.
|
||||
|
||||
Returns:
|
||||
MemStats object with current memory values in GB.
|
||||
"""
|
||||
tot, free, resv, alloc, frag = get_mem_info(empty_cache=empty_cache, unit="GB")
|
||||
return MemStats(tot=tot, free=free, resv=resv, alloc=alloc, frag=frag)
|
||||
|
||||
@final
|
||||
def _log_mem_summary(self, pre: MemStats, post: MemStats, expect_mem_change: bool) -> None:
|
||||
"""Log memory summary with diff between pre and post stats.
|
||||
|
||||
Logs one of three cases:
|
||||
1. Expected mem change: info log, magenta color
|
||||
2. Unexpected mem change: warning log, yellow color
|
||||
3. No mem change: debug log, no colors
|
||||
|
||||
Args:
|
||||
pre: Memory stats captured before the transform.
|
||||
post: Memory stats captured after the transform.
|
||||
expect_mem_change: Whether this transform is expected to cause memory changes.
|
||||
"""
|
||||
diff = post.diff(pre)
|
||||
|
||||
# Threshold for detecting significant memory changes (in GB)
|
||||
mem_change_threshold = 0.005
|
||||
|
||||
# Check if there was a significant memory change
|
||||
has_mem_change = (
|
||||
abs(diff.resv) >= mem_change_threshold
|
||||
or abs(diff.alloc) >= mem_change_threshold
|
||||
or abs(diff.frag) >= mem_change_threshold
|
||||
)
|
||||
|
||||
def _fmt_val_with_delta(val: float, delta: float, color: str) -> str:
|
||||
"""Format value with optional delta in the specified color."""
|
||||
val_str = f"{val:6.2f}GB"
|
||||
if abs(delta) < mem_change_threshold:
|
||||
return val_str
|
||||
sign = "+" if delta > 0 else ""
|
||||
if color:
|
||||
return f"{val_str} {_Colors.BOLD}{color}({sign}{delta:.2f}GB){_Colors.RESET}"
|
||||
return f"{val_str} ({sign}{delta:.2f}GB)"
|
||||
|
||||
def _fmt_parts(color: str) -> str:
|
||||
"""Format all memory parts with the specified color for deltas."""
|
||||
parts = [
|
||||
f"free: {_fmt_val_with_delta(post.free, diff.free, color)}",
|
||||
f"resv: {_fmt_val_with_delta(post.resv, diff.resv, color)}",
|
||||
f"alloc: {_fmt_val_with_delta(post.alloc, diff.alloc, color)}",
|
||||
f"frag: {_fmt_val_with_delta(post.frag, diff.frag, color)}",
|
||||
]
|
||||
return " | ".join(parts)
|
||||
|
||||
if has_mem_change and expect_mem_change:
|
||||
# Case 1: Expected mem change - info log, magenta
|
||||
label = f"{_Colors.MAGENTA}[CUDA MEM DIFF (EXPECTED)]{_Colors.RESET}"
|
||||
self._log_info(f"{label} {_fmt_parts(_Colors.MAGENTA)}")
|
||||
elif has_mem_change and not expect_mem_change:
|
||||
# Case 2: Unexpected mem change - warning log, yellow
|
||||
label = f"{_Colors.YELLOW}[CUDA MEM DIFF (UNEXPECTED)]{_Colors.RESET}"
|
||||
self._log_warning(f"{label} {_fmt_parts(_Colors.YELLOW)}")
|
||||
else:
|
||||
# Case 3: No mem change - debug log, no colors
|
||||
self._log_debug(f"[CUDA MEM] {_fmt_parts('')}")
|
||||
|
||||
@final
|
||||
def _get_autodeploy_meta(self, mod: nn.Module) -> AutodeployMeta:
|
||||
"""Get the autodeploy metadata from the graphmodule."""
|
||||
@ -450,8 +676,9 @@ class BaseTransform(ABC):
|
||||
clean_shape: bool,
|
||||
is_clean: bool,
|
||||
has_valid_shapes: bool,
|
||||
phase: str,
|
||||
) -> TransformInfo:
|
||||
"""Run graph cleanup before the transform.
|
||||
"""Run graph cleanup before or after the transform.
|
||||
|
||||
Args:
|
||||
mod: The model to run cleanup on.
|
||||
@ -459,23 +686,28 @@ class BaseTransform(ABC):
|
||||
clean_shape: Whether we want clean shapes after the transform.
|
||||
is_clean: The current cleanup status.
|
||||
has_valid_shapes: The current shape propagation status.
|
||||
phase: The phase of cleanup ("pre" or "post").
|
||||
|
||||
Returns:
|
||||
An info object indicating the cleanup status after this function is called.
|
||||
"""
|
||||
# check if run cleanup depending on the config and info
|
||||
if clean_shape and not (is_clean and has_valid_shapes):
|
||||
self._log_info("running graph cleanup (with shape_prop)")
|
||||
self._log_cleanup_status(phase, "ran", "graph canonicalization + shape_prop")
|
||||
canonicalize_graph(mod)
|
||||
with lift_to_meta(mod) if placeholders_on_meta(mod) else nullcontext():
|
||||
run_shape_prop(mod)
|
||||
is_clean = True
|
||||
has_valid_shapes = True
|
||||
elif clean_graph and not is_clean:
|
||||
self._log_info("running graph cleanup (no shape_prop)")
|
||||
self._log_cleanup_status(phase, "ran", "graph canonicalization")
|
||||
canonicalize_graph(mod)
|
||||
is_clean = True
|
||||
has_valid_shapes = False
|
||||
elif not clean_graph and not clean_shape:
|
||||
self._log_cleanup_status(phase, "skipped", "disabled")
|
||||
else:
|
||||
self._log_cleanup_status(phase, "skipped", "graph already clean")
|
||||
|
||||
return TransformInfo(is_clean=is_clean, has_valid_shapes=has_valid_shapes)
|
||||
|
||||
|
||||
@ -46,6 +46,9 @@ class BuildModel(BaseTransform):
|
||||
# build the model
|
||||
model = factory.build_model(self.config.device)
|
||||
|
||||
# update the kv cache config
|
||||
cm.update_kv_cache_config(**factory.get_cache_config_updates())
|
||||
|
||||
# by convention, we say the model is always clean
|
||||
info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
|
||||
|
||||
|
||||
@ -21,13 +21,14 @@ import torch
|
||||
from torch._ops import OpOverloadPacket
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from .....llmapi.llm_args import KvCacheConfig
|
||||
from ...custom_ops.attention_interface import (
|
||||
AttentionDescriptor,
|
||||
AttentionLayout,
|
||||
AttentionRegistry,
|
||||
CacheConfig,
|
||||
CacheInitializerDict,
|
||||
MHACallable,
|
||||
ResourceHandler,
|
||||
ResourceHandlerDict,
|
||||
SequenceInfo,
|
||||
)
|
||||
from ...models.factory import ModelFactory
|
||||
@ -195,12 +196,30 @@ class DetectHiddenStatesForCapture(BaseTransform):
|
||||
return gm, info
|
||||
|
||||
|
||||
class HiddenStatesResourceHandler(ResourceHandler):
|
||||
"""A resource handler for hidden states."""
|
||||
|
||||
def __init__(self, hidden_size: int, dtype: torch.dtype) -> None:
|
||||
"""Initialize the HiddenStatesResourceHandler.
|
||||
|
||||
Args:
|
||||
hidden_size: The size of the hidden states resource.
|
||||
dtype: The dtype of the hidden states resource.
|
||||
"""
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
|
||||
def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
|
||||
return torch.empty(
|
||||
sequence_info.max_num_tokens,
|
||||
self.hidden_size,
|
||||
device=sequence_info.device,
|
||||
dtype=self.dtype,
|
||||
)
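
# Illustrative sketch (hypothetical sizes): the handler defers allocation until the
# sequence info is known, e.g.
#   handler = HiddenStatesResourceHandler(hidden_size=4096, dtype=torch.bfloat16)
#   hidden_cache = handler.allocate(sequence_info)  # (max_num_tokens, 4096) on sequence_info.device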
|
||||
|
||||
@AttentionRegistry.register("cached_residual_add")
|
||||
class CachedResidualAdd(AttentionDescriptor):
|
||||
@classmethod
|
||||
def is_paged(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_attention_layout(cls) -> AttentionLayout:
|
||||
return "bsnd"
|
||||
@ -219,15 +238,12 @@ class CachedResidualAdd(AttentionDescriptor):
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
cls, source_attn_node: Node, cache_config: CacheConfig
|
||||
) -> CacheInitializerDict:
|
||||
cls, source_attn_node: Node, cache_config: KvCacheConfig
|
||||
) -> ResourceHandlerDict:
|
||||
hidden_size = source_attn_node.meta["val"].shape[-1]
|
||||
hidden_type = source_attn_node.meta["val"].dtype
|
||||
|
||||
def _get_hidden_states_cache(si: SequenceInfo):
|
||||
return torch.empty(si.max_num_tokens, hidden_size, device=si.device, dtype=hidden_type)
|
||||
|
||||
return {"hidden_states_cache": _get_hidden_states_cache}
|
||||
return {"hidden_states_cache": HiddenStatesResourceHandler(hidden_size, dtype=hidden_type)}
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
|
||||
@ -18,7 +18,7 @@
|
||||
|
||||
import inspect
|
||||
import operator
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -28,17 +28,13 @@ from torch.fx import GraphModule, Node
|
||||
from ...custom_ops.attention_interface import (
|
||||
AttentionDescriptor,
|
||||
AttentionRegistry,
|
||||
CacheConfig,
|
||||
Constant,
|
||||
PrepareMetadataCallable,
|
||||
)
|
||||
from ...distributed.common import all_gather_object, get_world_size
|
||||
from ...distributed.common import is_initialized as is_distributed_initialized
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils._graph import add_graph_input
|
||||
from ...utils.cuda_mem_tracker import get_mem_info_in_mb
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.cuda_mem_tracker import get_mem_info
|
||||
from ...utils.node_utils import is_op
|
||||
from ..interface import (
|
||||
BaseTransform,
|
||||
@ -53,9 +49,6 @@ class InsertCachedAttentionConfig(TransformConfig):
|
||||
"""Configuration for the insert cached attention transform."""
|
||||
|
||||
backend: Optional[str] = Field(default=None, description="The attention backend to use.")
|
||||
cache_config: CacheConfig = Field(
|
||||
default_factory=CacheConfig, description="The custom cache configuration to use."
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("insert_cached_attention")
|
||||
@ -154,7 +147,6 @@ class InsertCachedAttention(BaseTransform):
|
||||
meta_nodes_std: List[Node],
|
||||
meta_nodes_extra: List[Node],
|
||||
cache_nodes: List[Node],
|
||||
buffer_nodes: List[Node],
|
||||
constants: List[Constant],
|
||||
):
|
||||
"""Insert a cached attention node into the graph."""
|
||||
@ -166,7 +158,6 @@ class InsertCachedAttention(BaseTransform):
|
||||
*meta_nodes_std,
|
||||
*meta_nodes_extra,
|
||||
*cache_nodes,
|
||||
*buffer_nodes,
|
||||
*constants,
|
||||
),
|
||||
)
|
||||
@ -183,10 +174,6 @@ class InsertCachedAttention(BaseTransform):
|
||||
"""Replace uncached source attention node with corresponding cached attn node."""
|
||||
attn_descriptor = self.attn_descriptor
|
||||
|
||||
# run field-wise or to combine the cache config from the transform and the factory
|
||||
# the transform config takes precedence over the factory config
|
||||
cache_config = self.config.cache_config | factory.get_cache_config()
|
||||
|
||||
# Get all attention nodes and their info objects
|
||||
source_op = attn_descriptor.get_source_attention_op()
|
||||
|
||||
@ -199,10 +186,6 @@ class InsertCachedAttention(BaseTransform):
|
||||
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
|
||||
# Sanity check
|
||||
if cm.info.is_paged:
|
||||
assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op."
|
||||
|
||||
# get standard metadata nodes for all source attention nodes
|
||||
meta_nodes_std = self._process_metadata_std(gm, cm)
|
||||
|
||||
@ -212,8 +195,6 @@ class InsertCachedAttention(BaseTransform):
|
||||
# Register host-side prepare_metadata function for attention descriptor.
|
||||
self._process_metadata_host(cm)
|
||||
|
||||
buffer_in_lookup: Dict[str, Node] = {}
|
||||
|
||||
# replace fused attention node with attention node that has kv cache
|
||||
num_cached_attn_replacements = 0
|
||||
for idx, attn_node in enumerate(source_attn_nodes):
|
||||
@ -222,22 +203,13 @@ class InsertCachedAttention(BaseTransform):
|
||||
|
||||
# setup + store cache initializers and caches as input nodes
|
||||
cache_in_nodes = []
|
||||
for k, get_cache in attn_descriptor.get_cache_initializers(
|
||||
attn_node, cache_config
|
||||
for k, resource_handler in attn_descriptor.get_cache_initializers(
|
||||
attn_node, cm.kv_cache_config
|
||||
).items():
|
||||
k_indexed = f"{k}_{idx}"
|
||||
cm.add_cache(k_indexed, get_cache)
|
||||
cm.add_resource(k_indexed, resource_handler)
|
||||
cache_in_nodes.append(self._process_cache_node(gm, k_indexed))
|
||||
|
||||
# setup + store global buffer initializers and buffers as input nodes
|
||||
# NOTE: we have to check against existing keys to make sure nothing is registered twice...
|
||||
buffer_in_nodes = []
|
||||
for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items():
|
||||
if k not in buffer_in_lookup:
|
||||
cm.add_cache(k, get_buffer)
|
||||
buffer_in_lookup[k] = self._process_cache_node(gm, k)
|
||||
buffer_in_nodes.append(buffer_in_lookup[k]) # store buffer nodes for this op
|
||||
|
||||
# retrieve constants for attention_op
|
||||
constants = attn_descriptor.get_constants(attn_node)
|
||||
|
||||
@ -249,7 +221,6 @@ class InsertCachedAttention(BaseTransform):
|
||||
meta_nodes_std,
|
||||
meta_nodes_extra,
|
||||
cache_in_nodes,
|
||||
buffer_in_nodes,
|
||||
constants,
|
||||
)
|
||||
|
||||
@ -276,27 +247,15 @@ class InsertCachedMLAAttention(InsertCachedAttention):
|
||||
pass
|
||||
|
||||
|
||||
class ResizeKVCacheConfig(TransformConfig):
|
||||
"""Configuration for the resize kv cache transform."""
|
||||
|
||||
free_mem_ratio: float = Field(
|
||||
default=0.0, ge=0.0, le=1.0, description="The fraction of available memory to occupy."
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("resize_kv_cache")
|
||||
class ResizeKVCache(BaseTransform):
|
||||
"""Inflate the kv cache to occupy the available GPU memory.
|
||||
"""Resize the KV cache to occupy available GPU memory.
|
||||
|
||||
free_mem_ratio specifies the fraction of available memory to occupy.
|
||||
This implements the two-phase approach:
|
||||
1. Run a forward pass to allocate intermediate memory (activations, workspaces, etc.)
|
||||
2. Call resize_kv_cache_manager() to recreate KVCacheManager with optimal capacity
|
||||
"""
|
||||
|
||||
config: ResizeKVCacheConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return ResizeKVCacheConfig
|
||||
|
||||
def _apply_to_full_model(
|
||||
self,
|
||||
mod: nn.Module,
|
||||
@ -304,100 +263,29 @@ class ResizeKVCache(BaseTransform):
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[nn.Module, TransformInfo]:
|
||||
free_mem_ratio = self.config.free_mem_ratio
|
||||
|
||||
free_mem, total_mem = get_mem_info_in_mb(empty_cache=True)
|
||||
self._log_info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
|
||||
current_cache_size = cm.current_cache_size_bytes()
|
||||
current_kv_cache_size = getattr(cm, "current_kv_cache_size_bytes", None)
|
||||
current_kv_cache_size = (
|
||||
current_kv_cache_size() if callable(current_kv_cache_size) else current_cache_size
|
||||
)
|
||||
current_num_pages = cm.info.num_pages
|
||||
self._log_info(
|
||||
f"Current cache size (MB): {current_cache_size // 1024**2}, "
|
||||
f"Current num pages: {current_num_pages}"
|
||||
)
|
||||
if current_kv_cache_size != current_cache_size:
|
||||
self._log_info(
|
||||
f"Current KV-only cache size (MB): {current_kv_cache_size // 1024 // 1024}"
|
||||
)
|
||||
|
||||
if free_mem_ratio == 0.0:
|
||||
self._log_info(f"Skipping cache resize for {free_mem_ratio=}")
|
||||
# check if we need a resize or not
|
||||
if not cm.needs_resize():
|
||||
return mod, TransformInfo(
|
||||
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
|
||||
# TODO: the manual PyTorch workflow respects max_num_tokens if set and does _NOT_ resize
|
||||
# the cache in this case. Should we do the same here?
|
||||
|
||||
# Let's run a forward pass to get the memory usage
|
||||
# Run a forward pass to get the extra memory usage
|
||||
cm.info.set_max_num_tokens_sample()
|
||||
free_mem_pre, _ = get_mem_info_in_mb(empty_cache=True)
|
||||
self._log_info(f"Free memory before forward pass (MB): {free_mem_pre}")
|
||||
|
||||
# Reset peak memory stats to get the extra memory used during the forward pass
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
memory_allocated_before_forward_pass_mb = torch.cuda.memory_allocated() // 1024**2
|
||||
try:
|
||||
mod(**cm.named_args)
|
||||
except torch.OutOfMemoryError as e:
|
||||
ad_logger.error(
|
||||
self._log_info(
|
||||
f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n{e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
peak_memory_during_forward_pass_mb = torch.cuda.max_memory_allocated() // 1024**2
|
||||
mem_used_during_forward_pass_mb = (
|
||||
peak_memory_during_forward_pass_mb - memory_allocated_before_forward_pass_mb
|
||||
)
|
||||
self._log_info(
|
||||
f"Peak memory uasge during forward pass (MB): {peak_memory_during_forward_pass_mb}"
|
||||
)
|
||||
self._log_info(
|
||||
f"Extra memory used during forward pass (MB): {mem_used_during_forward_pass_mb}"
|
||||
)
|
||||
# NOTE: use fragmented memory without empty cache (peak forward memory + fragmented memory)
|
||||
# as a proxy for the memory reserved for the forward pass. This is a rough estimate and
|
||||
# may not be accurate.
|
||||
*_, mem_reserved_for_forward = get_mem_info(empty_cache=False, unit="B")
|
||||
|
||||
# TODO (lucaslie): logic needs overhaul, too much going on. For now, this is just reverting
|
||||
# to the original logic. Full overhaul will be done as part of #10013
|
||||
free_mem_post, _ = get_mem_info_in_mb(empty_cache=False)
|
||||
self._log_info(f"Free memory after forward pass (MB): {free_mem_post}")
|
||||
|
||||
memory_for_forward_pass = free_mem_pre - free_mem_post
|
||||
self._log_info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
|
||||
|
||||
# Compute new pages using KV-only bytes to avoid SSM/conv inflating per-page cost
|
||||
# Reserve headroom to avoid OOM from other allocations (workspaces, cudagraph pools, etc.)
|
||||
reserve_mb = max(1024, (total_mem * 5) // 100) # at least 1 GiB or 5% of total
|
||||
available_mb = max(0, free_mem_post - reserve_mb)
|
||||
|
||||
new_kv_total_bytes = int(
|
||||
available_mb * 1024 * 1024 * free_mem_ratio + current_kv_cache_size
|
||||
)
|
||||
per_page_bytes = max(1, current_kv_cache_size // max(1, current_num_pages))
|
||||
new_num_pages = int(new_kv_total_bytes // per_page_bytes)
|
||||
|
||||
# Need to sync all the GPUs if distributed group is initialized
|
||||
log_msg = f"Using local new_num_pages: {new_num_pages}"
|
||||
if is_distributed_initialized():
|
||||
gathered_num_pages = [None] * get_world_size()
|
||||
all_gather_object(gathered_num_pages, new_num_pages)
|
||||
new_num_pages = min(gathered_num_pages)
|
||||
log_msg = f"After all_gather - new_num_pages: {new_num_pages}"
|
||||
|
||||
self._log_info(log_msg)
|
||||
cm.resize_cache(new_num_pages)
|
||||
|
||||
# Log the final cache size for performance measurement, do not remove this log.
|
||||
final_cache_size_bytes = cm.current_cache_size_bytes()
|
||||
final_cache_size_gb = final_cache_size_bytes / (1024**3) # Convert to GiB
|
||||
self._log_info(
|
||||
f"Final KV cache size after resize: {final_cache_size_gb:.2f} GiB ({new_num_pages} pages)"
|
||||
)
|
||||
|
||||
# Free memory
|
||||
torch.cuda.empty_cache()
|
||||
# Resize - KVCacheManager will compute optimal capacity based on free memory
|
||||
cm.resize_kv_cache_manager(mem_reserved_for_forward)
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
@ -411,6 +299,13 @@ class ResizeKVCache(BaseTransform):
|
||||
|
||||
@TransformRegistry.register("initialize_cache")
|
||||
class InitializeCache(BaseTransform):
|
||||
"""Initialize KV caches using KVCacheManager.
|
||||
|
||||
Gets kv_cache_config from shared_config.ad_config and creates the KVCacheManager
|
||||
in estimation mode with conservative capacity. The ResizeKVCache transform will
|
||||
later recreate it with optimal capacity after measuring memory usage.
|
||||
"""
|
||||
|
||||
def _apply_to_full_model(
|
||||
self,
|
||||
mod: nn.Module,
|
||||
@ -418,7 +313,9 @@ class InitializeCache(BaseTransform):
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[nn.Module, TransformInfo]:
|
||||
num_caches = cm.initialize_caches()
|
||||
# Initialize with estimation mode
|
||||
# This allows resize_kv_cache to recreate with correct capacity after measuring memory
|
||||
num_caches = cm.initialize_resources()
|
||||
self._log_info(f"Initialized {num_caches} caches for cached attention")
|
||||
|
||||
info = TransformInfo(
|
||||
|
||||
@ -163,8 +163,8 @@ def get_cached_attn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
# metadata+caches+buffers with name lookup set up during kvcache transform
|
||||
*[kwargs[k] for k in module._node_ref.meta["metadata_cache_buffer_keys"]],
|
||||
# metadata+caches with name lookup set up during kvcache transform
|
||||
*[kwargs[k] for k in module._node_ref.meta["metadata_cache_keys"]],
|
||||
# constants set up during kvcache transform
|
||||
*module._node_ref.meta["constants"],
|
||||
)
|
||||
@ -242,17 +242,11 @@ class HFReplaceCachedAttn(InsertCachedAttention):
|
||||
meta_nodes_std: List[Node],
|
||||
meta_nodes_extra: List[Node],
|
||||
cache_nodes: List[Node],
|
||||
buffer_nodes: List[Node],
|
||||
constants: List[Constant],
|
||||
):
|
||||
"""Here we now need to actually do the correct mapping of the cached attn nodes."""
|
||||
# store reference to metadata, caches, buffers, and constants for this attn node
|
||||
attn_node.meta["metadata_cache_buffer_keys"] = (
|
||||
*meta_nodes_std,
|
||||
*meta_nodes_extra,
|
||||
*cache_nodes,
|
||||
*buffer_nodes,
|
||||
)
|
||||
# store reference to metadata, caches, and constants for this attn node
|
||||
attn_node.meta["metadata_cache_keys"] = (*meta_nodes_std, *meta_nodes_extra, *cache_nodes)
|
||||
attn_node.meta["constants"] = constants
|
||||
|
||||
def _apply_to_full_model(
|
||||
|
||||
@ -8,6 +8,7 @@ from pydantic import Field
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils._graph import move_to_device
|
||||
from ...utils.cuda_mem_tracker import bytes_to
|
||||
from ..interface import (
|
||||
BaseTransform,
|
||||
SharedConfig,
|
||||
@ -43,10 +44,11 @@ class LoadWeightsToDevice(BaseTransform):
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[nn.Module, TransformInfo]:
|
||||
factory.load_or_random_init(
|
||||
mod,
|
||||
device=self.config.checkpoint_device or cm.device,
|
||||
)
|
||||
params_size = sum(p.numel() * p.element_size() for p in mod.parameters())
|
||||
total_size_GB = bytes_to(params_size, unit="GB")
|
||||
self._log_info(f"Estimated parameters memory: {total_size_GB:.2f} GB")
|
||||
|
||||
factory.load_or_random_init(mod, device=self.config.checkpoint_device or cm.device)
|
||||
move_to_device(mod, cm.device)
|
||||
|
||||
info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True)
|
||||
|
||||
@ -16,12 +16,15 @@
|
||||
|
||||
import gc
|
||||
from contextlib import contextmanager
|
||||
from typing import Tuple
|
||||
from typing import Literal, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .logger import ad_logger
|
||||
|
||||
Number = Union[int, float]
|
||||
ByteUnit = Literal["B", "KB", "MB", "GB", "TB"]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cuda_memory_tracker(logger=ad_logger):
|
||||
@ -43,10 +46,38 @@ def cuda_memory_tracker(logger=ad_logger):
|
||||
logger.warning(f"Potential memory leak detected, leaked memory: {leaked} bytes")
|
||||
|
||||
|
||||
def get_mem_info_in_mb(empty_cache: bool = True) -> Tuple[int, int]:
|
||||
def bytes_to(bytes: int, *more_bytes: int, unit: ByteUnit) -> Union[Number, Tuple[Number, ...]]:
|
||||
units = {"KB": 1 << 10, "MB": 1 << 20, "GB": 1 << 30, "TB": 1 << 40}
|
||||
bytes_converted = (bytes,) + more_bytes
|
||||
unit = unit.upper()
|
||||
if unit != "B":
|
||||
bytes_converted = tuple(float(x) / units[unit.upper()] for x in bytes_converted)
|
||||
return bytes_converted if more_bytes else bytes_converted[0]
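
# Illustrative sketch (arbitrary inputs):
#   bytes_to(3 * (1 << 30), unit="GB")     # -> 3.0
#   bytes_to(1 << 20, 1 << 30, unit="MB")  # -> (1.0, 1024.0)
#   bytes_to(42, unit="B")                 # -> 42 (no conversion for bytes)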
|
||||
|
||||
def get_mem_info(
|
||||
empty_cache: bool = True, unit: ByteUnit = "B"
|
||||
) -> Tuple[Number, Number, Number, Number, Number]:
|
||||
"""Get the memory information of the current device.
|
||||
|
||||
Args:
|
||||
empty_cache: Whether to empty the memory cache.
|
||||
unit: The unit of the memory information. Defaults to bytes.
|
||||
|
||||
Returns:
|
||||
A tuple of the
|
||||
- total memory,
|
||||
- free memory,
|
||||
- reserved memory,
|
||||
- allocated memory,
|
||||
- fragmented memory
|
||||
in the specified unit.
|
||||
"""
|
||||
if empty_cache:
|
||||
# Clear the memory cache to get the exact free memory
|
||||
torch.cuda.empty_cache()
|
||||
free_mem, total_mem = torch.cuda.mem_get_info()
|
||||
MB = 1024**2
|
||||
return free_mem // MB, total_mem // MB
|
||||
res_mem = torch.cuda.memory_reserved()
|
||||
alloc_mem = torch.cuda.memory_allocated()
|
||||
frag_mem = res_mem - alloc_mem
|
||||
return bytes_to(total_mem, free_mem, res_mem, alloc_mem, frag_mem, unit=unit)
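
# Illustrative sketch of unpacking the returned tuple:
#   total, free, resv, alloc, frag = get_mem_info(empty_cache=True, unit="GB")
#   ad_logger.info(f"free={free:.2f}GB resv={resv:.2f}GB frag={frag:.2f}GB")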
|
||||
@ -193,6 +193,7 @@ class MambaCacheManager(BaseResourceManager):
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
|
||||
self.state_indices_list.clear()
|
||||
for r in request_ids:
|
||||
|
||||
@ -40,10 +40,10 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
|
||||
# Set it explicitly here to 8192 which is the default in build_config.
|
||||
"max_num_tokens": 8192,
|
||||
"skip_loading_weights": False,
|
||||
"kv_cache_config": {
|
||||
"free_gpu_memory_fraction": 0.7
|
||||
},
|
||||
"transforms": {
|
||||
"resize_kv_cache": {
|
||||
"free_mem_ratio": 0.7
|
||||
},
|
||||
"compile_model": {
|
||||
"backend":
|
||||
"torch-cudagraph",
|
||||
@ -55,7 +55,7 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
|
||||
if enable_chunked_prefill:
|
||||
config["enable_chunked_prefill"] = True
|
||||
config[
|
||||
"max_num_tokens"] = 512 # NOTE: must be > max(attn_page_size, max_batch_size)
|
||||
"max_num_tokens"] = 512 # NOTE: must be > max(tokens_per_block, max_batch_size)
|
||||
return config
|
||||
|
||||
def get_default_sampling_params(self):
|
||||
@ -110,7 +110,8 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
|
||||
"trust_remote_code": True,
|
||||
# SSMs do not support cache reuse.
|
||||
"kv_cache_config": {
|
||||
"enable_block_reuse": False
|
||||
"enable_block_reuse": False,
|
||||
"free_gpu_memory_fraction": 0.7
|
||||
},
|
||||
# Keep max_batch_size as in the PyTorch test to avoid OOM
|
||||
"max_batch_size": 128,
|
||||
@ -120,9 +121,6 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
|
||||
"max_num_tokens": 8192,
|
||||
"skip_loading_weights": False,
|
||||
"transforms": {
|
||||
"resize_kv_cache": {
|
||||
"free_mem_ratio": 0.7
|
||||
},
|
||||
"compile_model": {
|
||||
"backend": "torch-cudagraph",
|
||||
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
|
||||
@ -132,7 +130,7 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
|
||||
if enable_chunked_prefill:
|
||||
config["enable_chunked_prefill"] = True
|
||||
config[
|
||||
"max_num_tokens"] = 512 # NOTE: must be > max(attn_page_size, max_batch_size)
|
||||
"max_num_tokens"] = 512 # NOTE: must be > max(tokens_per_block, max_batch_size)
|
||||
return config
|
||||
|
||||
def get_default_sampling_params(self):
|
||||
@ -169,7 +167,10 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
|
||||
"trust_remote_code": True,
|
||||
# SSMs do not support cache reuse.
|
||||
"kv_cache_config": {
|
||||
"enable_block_reuse": False
|
||||
"enable_block_reuse": False,
|
||||
"free_gpu_memory_fraction": 0.7
|
||||
# NOTE: some accuracy benchmarks may require fp32 precision for mamba cache
|
||||
# "mamba_ssm_cache_dtype": "float32",
|
||||
},
|
||||
# Keep max_batch_size as in the PyTorch test to avoid OOM
|
||||
"max_batch_size": 128,
|
||||
@ -180,7 +181,6 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
|
||||
"max_num_tokens": 8192,
|
||||
"skip_loading_weights": False,
|
||||
"compile_backend": "torch-cudagraph",
|
||||
"free_mem_ratio": 0.7,
|
||||
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
|
||||
"transforms": {
|
||||
"detect_sharding": {
|
||||
@ -191,12 +191,6 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
|
||||
"stage": "compile",
|
||||
"enabled": True,
|
||||
},
|
||||
# NOTE: some accuracy benchmarks may require fp32 precision for mamba cache
|
||||
# "insert_cached_ssm_attention": {
|
||||
# "cache_config": {
|
||||
# "mamba_dtype": "float32",
|
||||
# },
|
||||
# },
|
||||
}
|
||||
}
|
||||
|
||||
@ -213,7 +207,8 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
|
||||
kwargs = self.get_default_kwargs()
|
||||
# TODO: multi-stream MOE seems to increase the memory usage
|
||||
kwargs["max_batch_size"] = 32
|
||||
kwargs["free_mem_ratio"] = 0.4
|
||||
kwargs["kv_cache_config"] = {"free_gpu_memory_fraction": 0.4}
|
||||
sampling_params = self.get_default_sampling_params()
|
||||
with AutoDeployLLM(model=self.MODEL_PATH_BF16,
|
||||
tokenizer=self.MODEL_PATH_BF16,
|
||||
**kwargs) as llm:
|
||||
@ -279,7 +274,6 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
|
||||
"trust_remote_code": True,
|
||||
"skip_loading_weights": False,
|
||||
"compile_backend": "torch-cudagraph",
|
||||
"free_mem_ratio": 0.9,
|
||||
"max_batch_size": 128,
|
||||
"max_seq_len": self.MAX_SEQ_LEN,
|
||||
"max_num_tokens": self.MAX_SEQ_LEN,
|
||||
|
||||
@ -576,7 +576,6 @@ class PerfTestConfig:
|
||||
extra: bool = False,
|
||||
# _autodeploy backend specific parameters
|
||||
ad_compile_backend: str = "torch-opt",
|
||||
free_mem_ratio: float = 0.9,
|
||||
extra_runtime: str = "trtllm",
|
||||
skip_loading_weights: bool = False,
|
||||
):
|
||||
@ -636,7 +635,6 @@ class PerfTestConfig:
|
||||
self.extra = extra
|
||||
# _autodeploy backend specific parameters
|
||||
self.ad_compile_backend = ad_compile_backend
|
||||
self.free_mem_ratio = free_mem_ratio
|
||||
self.extra_runtime = extra_runtime
|
||||
self.skip_loading_weights = skip_loading_weights
|
||||
# Just build engines
|
||||
@ -1422,9 +1420,6 @@ class MultiMetricPerfTest(AbstractPerfScriptTestClass):
|
||||
'compile_model': {
|
||||
'backend': self._config.ad_compile_backend
|
||||
},
|
||||
'resize_kv_cache': {
|
||||
'free_mem_ratio': self._config.free_mem_ratio
|
||||
},
|
||||
},
|
||||
'runtime': self._config.extra_runtime,
|
||||
'skip_loading_weights': self._config.skip_loading_weights
|
||||
|
||||
@ -18,11 +18,11 @@ from tensorrt_llm._torch.auto_deploy.transform.library.sharding import ShardingT
|
||||
|
||||
|
||||
class FakeFactory(ModelFactory):
|
||||
"""Dummy factory to pass cache_config for testing."""
|
||||
"""Dummy factory to pass cache_config_updates for testing."""
|
||||
|
||||
def __init__(self, model=None, cache_config=None, quant_config=None):
|
||||
def __init__(self, model=None, cache_config_updates=None, quant_config=None):
|
||||
self._model = model
|
||||
self.cache_config = cache_config
|
||||
self.cache_config_updates = cache_config_updates
|
||||
self.quant_config = quant_config
|
||||
|
||||
def build_model(self, device: str):
|
||||
@ -34,8 +34,8 @@ class FakeFactory(ModelFactory):
|
||||
def _load_checkpoint(self, model, device):
|
||||
return
|
||||
|
||||
def get_cache_config(self):
|
||||
return self.cache_config
|
||||
def get_cache_config_updates(self):
|
||||
return self.cache_config_updates
|
||||
|
||||
def get_quant_config(self):
|
||||
return self.quant_config
|
||||
|
||||
@ -511,7 +511,10 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, An
|
||||
|
||||
# add some defaults to llm_args
|
||||
llm_args["skip_loading_weights"] = True # No weight loading to speed up things
|
||||
llm_args["attn_page_size"] = 4 # Make sure paging is activated despite small max_tokens
|
||||
llm_args["kv_cache_config"] = {
|
||||
"tokens_per_block": 4, # Make sure paging is activated despite small max_tokens
|
||||
"free_gpu_memory_fraction": 0.0, # No resizing of the cache to keep the mem footprint small
|
||||
}
|
||||
llm_args["max_batch_size"] = 2 # Minimum batching to speed up things
|
||||
# update with custom llm_args kwargs
|
||||
llm_args.update(llm_args_kwargs)
|
||||
|
||||
@ -84,8 +84,7 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
|
||||
)
|
||||
|
||||
# make sure planner is initialized
|
||||
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
|
||||
_GlobalFlashInferPlanner.init_workspace(workspace)
|
||||
_GlobalFlashInferPlanner.reset(torch.device(device))
|
||||
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr,
|
||||
@ -118,8 +117,6 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# BUFFERS
|
||||
workspace,
|
||||
# CONSTANTS
|
||||
None,
|
||||
1.0,
|
||||
@ -229,8 +226,7 @@ def test_flashinfer_attention_op_decode(
|
||||
)
|
||||
|
||||
# make sure planner is initialized
|
||||
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
|
||||
_GlobalFlashInferPlanner.init_workspace(workspace)
|
||||
_GlobalFlashInferPlanner.reset(torch.device(device))
|
||||
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr,
|
||||
@ -262,8 +258,6 @@ def test_flashinfer_attention_op_decode(
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# BUFFERS
|
||||
workspace,
|
||||
# CONSTANTS
|
||||
None,
|
||||
1.0,
|
||||
@ -361,8 +355,7 @@ def test_flashinfer_attention_context_and_generate(
|
||||
)
|
||||
|
||||
# make sure planner is initialized
|
||||
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
|
||||
_GlobalFlashInferPlanner.init_workspace(workspace)
|
||||
_GlobalFlashInferPlanner.reset(torch.device(device))
|
||||
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr,
|
||||
@ -395,8 +388,6 @@ def test_flashinfer_attention_context_and_generate(
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# BUFFERS
|
||||
workspace,
|
||||
# CONSTANTS
|
||||
None,
|
||||
1.0,
|
||||
@ -454,7 +445,7 @@ def test_flashinfer_attention_context_and_generate(
|
||||
v_3 = torch.randn(BATCH_SIZE, 1, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
|
||||
# Create FlashInferAttention class before calling the custom op
|
||||
_GlobalFlashInferPlanner.reset()
|
||||
_GlobalFlashInferPlanner.reset(torch.device(device))
|
||||
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr,
|
||||
@ -485,8 +476,6 @@ def test_flashinfer_attention_context_and_generate(
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# BUFFERS
|
||||
workspace,
|
||||
# CONSTANTS
|
||||
None,
|
||||
1.0,
|
||||
@ -586,8 +575,7 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
|
||||
)
|
||||
|
||||
# make sure planner is initialized
|
||||
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
|
||||
_GlobalFlashInferPlanner.init_workspace(workspace)
|
||||
_GlobalFlashInferPlanner.reset(torch.device(device))
|
||||
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr,
|
||||
@ -620,8 +608,6 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# BUFFERS
|
||||
workspace,
|
||||
# CONSTANTS
|
||||
None,
|
||||
1.0,
|
||||
@ -748,8 +734,7 @@ def test_flashinfer_attention_with_fp8_cache(
|
||||
v_cache = v_cache.to(torch.float8_e4m3fn)
|
||||
|
||||
# make sure planner is initialized
|
||||
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
|
||||
_GlobalFlashInferPlanner.init_workspace(workspace)
|
||||
_GlobalFlashInferPlanner.reset(torch.device(device))
|
||||
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr,
|
||||
@ -782,8 +767,6 @@ def test_flashinfer_attention_with_fp8_cache(
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# BUFFERS
|
||||
workspace,
|
||||
# CONSTANTS
|
||||
None,
|
||||
K_SCALE,
|
||||
@ -859,8 +842,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
|
||||
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
|
||||
|
||||
# make sure planner is initialized
|
||||
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
|
||||
_GlobalFlashInferPlanner.init_workspace(workspace)
|
||||
_GlobalFlashInferPlanner.reset(torch.device(device))
|
||||
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr,
|
||||
@ -891,8 +873,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# BUFFERS
|
||||
workspace,
|
||||
# CONSTANTS
|
||||
None,
|
||||
1.0,
|
||||
@ -956,7 +936,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
|
||||
seq_len_with_cache2_host = (offsets2 + seq_len_tensor2).cpu()
|
||||
|
||||
# Create FlashInferAttention class before calling the custom op
|
||||
_GlobalFlashInferPlanner.reset()
|
||||
_GlobalFlashInferPlanner.reset(torch.device(device))
|
||||
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr2,
|
||||
@ -987,8 +967,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# BUFFERS
|
||||
workspace,
|
||||
# CONSTANTS
|
||||
None,
|
||||
1.0,
|
||||
|
||||
@ -0,0 +1,237 @@
|
||||
"""Unit tests for ResourceHandler classes in attention_interface.py.
|
||||
|
||||
Tests the new resource handler abstraction for cache management:
|
||||
- PagedResourceHandler (for paged KV caches)
|
||||
- StateResourceHandler (for SSM/conv states)
|
||||
- UnpagedResourceHandler (for unpaged local caches)
|
||||
- AttentionDescriptor.resolve_cache_dtype()
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import (
|
||||
AttentionDescriptor,
|
||||
ManagedResourceHandler,
|
||||
PagedResourceHandler,
|
||||
ResourceHandler,
|
||||
SequenceInfo,
|
||||
StateResourceHandler,
|
||||
UnpagedResourceHandler,
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# PagedResourceHandler Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_paged_handler_stores_token_shape_and_dtype():
|
||||
"""Verify PagedResourceHandler stores token_shape and dtype correctly."""
|
||||
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
assert handler.token_shape == (8, 64)
|
||||
assert handler.dtype == torch.float16
|
||||
|
||||
|
||||
def test_paged_handler_single_dimension_token_shape():
|
||||
"""Test PagedResourceHandler with single dimension token shape."""
|
||||
handler = PagedResourceHandler(128, dtype=torch.bfloat16)
|
||||
assert handler.token_shape == (128,)
|
||||
assert handler.dtype == torch.bfloat16
|
||||
|
||||
|
||||
def test_paged_handler_multi_dimension_token_shape():
|
||||
"""Test PagedResourceHandler with multiple dimension token shape."""
|
||||
handler = PagedResourceHandler(4, 8, 16, dtype=torch.float32)
|
||||
assert handler.token_shape == (4, 8, 16)
|
||||
assert handler.dtype == torch.float32
|
||||
|
||||
|
||||
def test_paged_handler_allocate_raises_not_implemented():
|
||||
"""Verify PagedResourceHandler.allocate() raises NotImplementedError."""
|
||||
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4)
|
||||
|
||||
with pytest.raises(NotImplementedError, match="Managed resources should not be allocated"):
|
||||
handler.allocate(seq_info)
|
||||
|
||||
|
||||
def test_paged_handler_is_resource_handler():
|
||||
"""Verify PagedResourceHandler is a ResourceHandler subclass."""
|
||||
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
assert isinstance(handler, ResourceHandler)
|
||||
|
||||
|
||||
def test_paged_handler_is_managed_resource():
|
||||
"""Verify PagedResourceHandler is a ManagedResourceHandler."""
|
||||
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
assert isinstance(handler, ManagedResourceHandler)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# StateResourceHandler Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_state_handler_stores_state_shape_and_dtype():
|
||||
"""Verify StateResourceHandler stores state_shape and dtype correctly."""
|
||||
handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
|
||||
assert handler.state_shape == (4, 64, 16)
|
||||
assert handler.dtype == torch.bfloat16
|
||||
|
||||
|
||||
def test_state_handler_single_dimension_state_shape():
|
||||
"""Test StateResourceHandler with single dimension state shape."""
|
||||
handler = StateResourceHandler(256, dtype=torch.float16)
|
||||
assert handler.state_shape == (256,)
|
||||
assert handler.dtype == torch.float16
|
||||
|
||||
|
||||
def test_state_handler_conv_state_shape():
|
||||
"""Test StateResourceHandler with typical conv state shape [in_channels, kernel_size-1]."""
|
||||
handler = StateResourceHandler(512, 3, dtype=torch.bfloat16)
|
||||
assert handler.state_shape == (512, 3)
|
||||
assert handler.dtype == torch.bfloat16
|
||||
|
||||
|
||||
def test_state_handler_ssm_state_shape():
|
||||
"""Test StateResourceHandler with typical SSM state shape [num_heads, head_dim, ssm_state_size]."""
|
||||
handler = StateResourceHandler(4, 64, 16, dtype=torch.float32)
|
||||
assert handler.state_shape == (4, 64, 16)
|
||||
assert handler.dtype == torch.float32
|
||||
|
||||
|
||||
def test_state_handler_allocate_raises_not_implemented():
|
||||
"""Verify StateResourceHandler.allocate() raises NotImplementedError."""
|
||||
handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
|
||||
seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4)
|
||||
|
||||
with pytest.raises(NotImplementedError, match="Managed resources should not be allocated"):
|
||||
handler.allocate(seq_info)
|
||||
|
||||
|
||||
def test_state_handler_is_resource_handler():
|
||||
"""Verify StateResourceHandler is a ResourceHandler subclass."""
|
||||
handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
|
||||
assert isinstance(handler, ResourceHandler)
|
||||
|
||||
|
||||
def test_state_handler_is_managed_resource():
|
||||
"""Verify StateResourceHandler is a ManagedResourceHandler."""
|
||||
handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
|
||||
assert isinstance(handler, ManagedResourceHandler)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# UnpagedResourceHandler Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_unpaged_handler_stores_token_shape_and_dtype():
|
||||
"""Verify UnpagedResourceHandler stores token_shape and dtype correctly."""
|
||||
handler = UnpagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
assert handler.token_shape == (8, 64)
|
||||
assert handler.dtype == torch.float16
|
||||
|
||||
|
||||
def test_unpaged_handler_single_dimension_token_shape():
|
||||
"""Test UnpagedResourceHandler with single dimension token shape."""
|
||||
handler = UnpagedResourceHandler(128, dtype=torch.bfloat16)
|
||||
assert handler.token_shape == (128,)
|
||||
assert handler.dtype == torch.bfloat16
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_kv_heads,head_dim,dtype",
    [
        (8, 64, torch.float16),
        (4, 128, torch.bfloat16),
        (1, 64, torch.float32),
    ],
)
def test_unpaged_handler_allocate_returns_correct_shape(num_kv_heads, head_dim, dtype):
    """Verify UnpagedResourceHandler.allocate() returns tensor with correct shape."""
    max_batch_size = 4
    max_seq_len = 128

    handler = UnpagedResourceHandler(num_kv_heads, head_dim, dtype=dtype)
    seq_info = SequenceInfo(max_seq_len=max_seq_len, max_batch_size=max_batch_size)
    seq_info.to("cuda")

    tensor = handler.allocate(seq_info)

    expected_shape = (seq_info.max_num_state_slots, max_seq_len, num_kv_heads, head_dim)
    assert tensor.shape == expected_shape
    assert tensor.dtype == dtype
    assert tensor.device.type == "cuda"


def test_unpaged_handler_allocate_correct_device():
|
||||
"""Verify UnpagedResourceHandler allocated tensor is on the correct device."""
|
||||
handler = UnpagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4)
|
||||
seq_info.to("cuda")
|
||||
|
||||
tensor = handler.allocate(seq_info)
|
||||
assert tensor.device == seq_info.device
|
||||
|
||||
|
||||
def test_unpaged_handler_is_resource_handler():
|
||||
"""Verify UnpagedResourceHandler is a ResourceHandler subclass."""
|
||||
handler = UnpagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
assert isinstance(handler, ResourceHandler)
|
||||
|
||||
|
||||
def test_unpaged_handler_is_not_managed_resource():
|
||||
"""Verify UnpagedResourceHandler is NOT a ManagedResourceHandler."""
|
||||
handler = UnpagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
assert not isinstance(handler, ManagedResourceHandler)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# AttentionDescriptor.resolve_cache_dtype() Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_resolve_cache_dtype_auto_returns_fallback_float16():
|
||||
"""Test 'auto' returns the fallback dtype (float16)."""
|
||||
result = AttentionDescriptor.resolve_cache_dtype("auto", torch.float16)
|
||||
assert result == torch.float16
|
||||
|
||||
|
||||
def test_resolve_cache_dtype_auto_returns_fallback_bfloat16():
|
||||
"""Test 'auto' returns the fallback dtype (bfloat16)."""
|
||||
result = AttentionDescriptor.resolve_cache_dtype("auto", torch.bfloat16)
|
||||
assert result == torch.bfloat16
|
||||
|
||||
|
||||
def test_resolve_cache_dtype_auto_returns_fallback_float32():
|
||||
"""Test 'auto' returns the fallback dtype (float32)."""
|
||||
result = AttentionDescriptor.resolve_cache_dtype("auto", torch.float32)
|
||||
assert result == torch.float32
|
||||
|
||||
|
||||
def test_resolve_cache_dtype_explicit_float16():
|
||||
"""Test explicit 'float16' dtype string resolves correctly."""
|
||||
result = AttentionDescriptor.resolve_cache_dtype("float16", torch.bfloat16)
|
||||
assert result == torch.float16
|
||||
|
||||
|
||||
def test_resolve_cache_dtype_explicit_bfloat16():
|
||||
"""Test explicit 'bfloat16' dtype string resolves correctly."""
|
||||
result = AttentionDescriptor.resolve_cache_dtype("bfloat16", torch.float16)
|
||||
assert result == torch.bfloat16
|
||||
|
||||
|
||||
def test_resolve_cache_dtype_explicit_float32():
|
||||
"""Test explicit 'float32' dtype string resolves correctly."""
|
||||
result = AttentionDescriptor.resolve_cache_dtype("float32", torch.float16)
|
||||
assert result == torch.float32
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability(0) < (8, 9), reason="FP8 requires compute capability >= 8.9"
|
||||
)
|
||||
def test_resolve_cache_dtype_explicit_fp8():
|
||||
"""Test explicit 'fp8' dtype string resolves correctly."""
|
||||
result = AttentionDescriptor.resolve_cache_dtype("fp8", torch.float16)
|
||||
assert result == torch.float8_e4m3fn
|
||||
@ -329,7 +329,7 @@ def test_trtllm_fused_moe_fp8(
        pytest.skip(f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})")

    assert itype == torch.float8_e4m3fn and wtype == torch.float8_e4m3fn, (
        "FP8 test only supports float8_e4m3fn"
        "FP8 test only supports torch.float8_e4m3fn"
    )
    assert otype == torch.bfloat16 or otype == torch.float16, (
        "FP8 test only supports bfloat16 or float16 output type"

@ -0,0 +1,724 @@
"""Unit tests for CachedSequenceInterface in interface.py.

Tests the refactored CachedSequenceInterface which now:
- Creates SequenceInfo internally with tokens_per_block from KvCacheConfig
- Manages resources via KVCacheManager or MambaHybridCacheManager
- Supports paged resources (KV caches) and state resources (SSM states)
"""

import pytest
import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import (
    PagedResourceHandler,
    SequenceInfo,
    StateResourceHandler,
    UnpagedResourceHandler,
)
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import MambaHybridCacheManager
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
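

# Illustrative sketch (not collected by pytest and not part of the production code):
# the typical lifecycle these unit tests exercise, assuming a CUDA device is
# available. The shapes and KvCacheConfig values below are placeholders.
def _example_cached_sequence_lifecycle():
    kv_cache_config = KvCacheConfig(tokens_per_block=32, max_tokens=1024)
    csi = CachedSequenceInterface(
        max_seq_len=128,
        max_batch_size=4,
        device="cuda",
        kv_cache_config=kv_cache_config,
    )
    # Register per-layer resources, then let the interface build the cache manager.
    csi.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
    csi.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
    csi.initialize_resources()  # creates a KVCacheManager and the cache views
    csi.shutdown()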

# =============================================================================
# Fixtures
# =============================================================================


@pytest.fixture
def default_kv_cache_config():
    """KvCacheConfig with default settings."""
    return KvCacheConfig()


@pytest.fixture
def paged_kv_cache_config():
    """KvCacheConfig with paging enabled and no resizing."""
    return KvCacheConfig(
        tokens_per_block=32,
        max_tokens=1024,
        free_gpu_memory_fraction=0.0,  # Disable dynamic resizing
    )


@pytest.fixture
def resizable_kv_cache_config():
    """KvCacheConfig with dynamic resizing enabled."""
    return KvCacheConfig(
        tokens_per_block=32,
        max_tokens=1024,
        free_gpu_memory_fraction=0.5,
    )


# =============================================================================
# Initialization Tests
# =============================================================================


def test_init_creates_sequence_info_with_tokens_per_block(paged_kv_cache_config):
|
||||
"""Verify SequenceInfo is created with correct tokens_per_block from KvCacheConfig."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
assert interface.info.tokens_per_block == paged_kv_cache_config.tokens_per_block
|
||||
assert interface.info.max_seq_len == 128
|
||||
assert interface.info.max_batch_size == 4
|
||||
|
||||
|
||||
def test_init_uses_default_kv_cache_config_when_not_provided():
|
||||
"""Verify default KvCacheConfig is used when not provided."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
# Default KvCacheConfig should be created
|
||||
assert interface.kv_cache_config is not None
|
||||
# Default tokens_per_block is 64 in KvCacheConfig
|
||||
assert interface.info.tokens_per_block == interface.kv_cache_config.tokens_per_block
|
||||
|
||||
|
||||
def test_init_propagates_max_num_tokens():
|
||||
"""Verify max_num_tokens propagates to SequenceInfo."""
|
||||
max_num_tokens = 512
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
max_num_tokens=max_num_tokens,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
assert interface.info.max_num_tokens == max_num_tokens
|
||||
|
||||
|
||||
def test_init_propagates_vocab_size_padded():
|
||||
"""Verify vocab_size_padded propagates to SequenceInfo."""
|
||||
vocab_size_padded = 32000
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
vocab_size_padded=vocab_size_padded,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
assert interface.info.vocab_size_padded == vocab_size_padded
|
||||
|
||||
|
||||
def test_init_stores_device():
|
||||
"""Verify device is stored correctly."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda:0",
|
||||
)
|
||||
|
||||
assert interface.device == "cuda:0"
|
||||
|
||||
|
||||
def test_init_default_device_is_cuda():
|
||||
"""Verify default device is 'cuda' when not specified."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
)
|
||||
|
||||
assert interface.device == "cuda"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Resource Registration Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_add_resource_paged_handler(paged_kv_cache_config):
|
||||
"""Test adding a PagedResourceHandler resource."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
interface.add_resource("k_cache_0", handler)
|
||||
|
||||
assert "k_cache_0" in interface._resource_lookup
|
||||
assert interface._resource_lookup["k_cache_0"] is handler
|
||||
|
||||
|
||||
def test_add_resource_state_handler(paged_kv_cache_config):
|
||||
"""Test adding a StateResourceHandler resource."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
|
||||
interface.add_resource("ssm_state_0", handler)
|
||||
|
||||
assert interface._resource_lookup["ssm_state_0"] is handler
|
||||
|
||||
|
||||
def test_add_resource_unpaged_handler(paged_kv_cache_config):
|
||||
"""Test adding an UnpagedResourceHandler resource."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
handler = UnpagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
interface.add_resource("unpaged_cache", handler)
|
||||
|
||||
assert "unpaged_cache" in interface._resource_lookup
|
||||
assert interface._resource_lookup["unpaged_cache"] is handler
|
||||
|
||||
|
||||
def test_add_multiple_resources(paged_kv_cache_config):
|
||||
"""Test adding multiple resources of different types."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
k_handler = PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
v_handler = PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
ssm_handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
|
||||
|
||||
interface.add_resource("k_cache_0", k_handler)
|
||||
interface.add_resource("v_cache_0", v_handler)
|
||||
interface.add_resource("ssm_state_0", ssm_handler)
|
||||
|
||||
assert len(interface._resource_lookup) == 3
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Resource Initialization Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_initialize_resources_paged_only_creates_kv_cache_manager(paged_kv_cache_config):
|
||||
"""Test paged-only resources create KVCacheManager (not MambaHybridCacheManager)."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
# Add only paged resources
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
|
||||
num_caches = interface.initialize_resources()
|
||||
|
||||
assert num_caches == 2
|
||||
assert isinstance(interface.kv_cache_manager, KVCacheManager)
|
||||
assert not isinstance(interface.kv_cache_manager, MambaHybridCacheManager)
|
||||
|
||||
|
||||
def test_initialize_resources_mixed_creates_mamba_hybrid_cache_manager(paged_kv_cache_config):
|
||||
"""Test mixed paged + state resources create MambaHybridCacheManager."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
# Add paged and state resources
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("ssm_state_0", StateResourceHandler(4, 64, 16, dtype=torch.bfloat16))
|
||||
|
||||
num_caches = interface.initialize_resources()
|
||||
|
||||
assert num_caches == 3
|
||||
assert isinstance(interface.kv_cache_manager, MambaHybridCacheManager)
|
||||
|
||||
|
||||
def test_initialize_resources_creates_cache_views_with_correct_shape(paged_kv_cache_config):
|
||||
"""Verify cache views are created with correct shapes."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
num_kv_heads = 8
|
||||
head_dim = 64
|
||||
interface.add_resource(
|
||||
"k_cache_0", PagedResourceHandler(num_kv_heads, head_dim, dtype=torch.float16)
|
||||
)
|
||||
interface.add_resource(
|
||||
"v_cache_0", PagedResourceHandler(num_kv_heads, head_dim, dtype=torch.float16)
|
||||
)
|
||||
|
||||
interface.initialize_resources()
|
||||
|
||||
# Check cache views exist
|
||||
assert "k_cache_0" in interface._caches
|
||||
assert "v_cache_0" in interface._caches
|
||||
|
||||
# Check shapes: [num_blocks, tokens_per_block, num_kv_heads, head_dim]
|
||||
k_cache = interface._caches["k_cache_0"]
|
||||
assert k_cache is not None
|
||||
assert k_cache.shape[1] == paged_kv_cache_config.tokens_per_block
|
||||
assert k_cache.shape[2] == num_kv_heads
|
||||
assert k_cache.shape[3] == head_dim
|
||||
assert k_cache.dtype == torch.float16
|
||||
|
||||
|
||||
def test_initialize_resources_creates_state_views_with_correct_shape(paged_kv_cache_config):
|
||||
"""Verify state views are created with correct shapes for MambaHybridCacheManager."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
num_heads = 4
|
||||
head_dim = 64
|
||||
ssm_state_size = 16
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource(
|
||||
"ssm_state_0",
|
||||
StateResourceHandler(num_heads, head_dim, ssm_state_size, dtype=torch.bfloat16),
|
||||
)
|
||||
|
||||
interface.initialize_resources()
|
||||
|
||||
# Check state view exists
|
||||
ssm_cache = interface._caches["ssm_state_0"]
|
||||
assert ssm_cache is not None
|
||||
# Shape: [num_states, num_heads, head_dim, ssm_state_size]
|
||||
assert ssm_cache.shape[1] == num_heads
|
||||
assert ssm_cache.shape[2] == head_dim
|
||||
assert ssm_cache.shape[3] == ssm_state_size
|
||||
assert ssm_cache.dtype == torch.bfloat16
|
||||
|
||||
|
||||
def test_initialize_resources_unpaged_allocated_locally(paged_kv_cache_config):
|
||||
"""Verify UnpagedResourceHandler resources are allocated locally (not via cache manager)."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
num_kv_heads = 8
|
||||
head_dim = 64
|
||||
interface.add_resource(
|
||||
"unpaged_cache", UnpagedResourceHandler(num_kv_heads, head_dim, dtype=torch.float16)
|
||||
)
|
||||
|
||||
interface.initialize_resources()
|
||||
|
||||
# Check unpaged cache was allocated
|
||||
assert "unpaged_cache" in interface._caches
|
||||
unpaged_cache = interface._caches["unpaged_cache"]
|
||||
assert unpaged_cache is not None
|
||||
# Shape: [max_batch_size + 1, max_seq_len, num_kv_heads, head_dim]
|
||||
assert unpaged_cache.shape == (4 + 1, 128, num_kv_heads, head_dim)
|
||||
|
||||
|
||||
def test_is_paged_returns_true_for_paged_only(paged_kv_cache_config):
|
||||
"""Test is_paged() returns True when all resources are paged."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.initialize_resources()
|
||||
|
||||
assert interface.is_paged() is True
|
||||
|
||||
|
||||
def test_is_paged_returns_false_for_hybrid(paged_kv_cache_config):
|
||||
"""Test is_paged() returns False when state resources exist."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("ssm_state_0", StateResourceHandler(4, 64, 16, dtype=torch.bfloat16))
|
||||
interface.initialize_resources()
|
||||
|
||||
assert interface.is_paged() is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# KV Cache Resize Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_needs_resize_returns_false_when_fraction_is_zero(paged_kv_cache_config):
|
||||
"""Test needs_resize() returns False when free_gpu_memory_fraction is 0.0."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.initialize_resources()
|
||||
|
||||
assert interface.needs_resize() is False
|
||||
|
||||
|
||||
def test_needs_resize_returns_true_when_fraction_is_positive(resizable_kv_cache_config):
|
||||
"""Test needs_resize() returns True when free_gpu_memory_fraction is positive."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=resizable_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.initialize_resources()
|
||||
|
||||
assert interface.needs_resize() is True
|
||||
|
||||
|
||||
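# A minimal sketch of the resize rule pinned down by the two tests above (an
# assumption drawn from their expectations, not the actual implementation):
# a resize is requested only when free_gpu_memory_fraction is positive.
def _needs_resize_sketch(kv_cache_config: KvCacheConfig) -> bool:
    return (kv_cache_config.free_gpu_memory_fraction or 0.0) > 0.0

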
def test_resize_kv_cache_manager_skipped_when_not_needed(paged_kv_cache_config):
|
||||
"""Test resize_kv_cache_manager() does nothing when resize not needed."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.initialize_resources()
|
||||
|
||||
# Get initial state
|
||||
initial_manager = interface.kv_cache_manager
|
||||
|
||||
# Resize should be a no-op
|
||||
interface.resize_kv_cache_manager()
|
||||
|
||||
# Manager should be the same object (no recreation)
|
||||
assert interface.kv_cache_manager is initial_manager
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Shutdown and Cleanup Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_shutdown_clears_caches(paged_kv_cache_config):
|
||||
"""Test shutdown() clears all caches."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.initialize_resources()
|
||||
|
||||
assert len(interface._caches) == 2
|
||||
|
||||
interface.shutdown()
|
||||
|
||||
assert len(interface._caches) == 0
|
||||
|
||||
|
||||
def test_clear_cache_views_sets_views_to_none(paged_kv_cache_config):
|
||||
"""Test _clear_cache_views() sets paged and state cache views to None."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("ssm_state_0", StateResourceHandler(4, 64, 16, dtype=torch.bfloat16))
|
||||
interface.initialize_resources()
|
||||
|
||||
# Manually call _clear_cache_views
|
||||
interface._clear_cache_views()
|
||||
|
||||
# Paged and state caches should be None
|
||||
assert interface._caches["k_cache_0"] is None
|
||||
assert interface._caches["ssm_state_0"] is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration Update Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_update_kv_cache_config_valid_field():
|
||||
"""Test update_kv_cache_config() with valid field updates."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
interface.update_kv_cache_config(tokens_per_block=64)
|
||||
|
||||
assert interface.kv_cache_config.tokens_per_block == 64
|
||||
|
||||
|
||||
def test_update_kv_cache_config_multiple_fields():
|
||||
"""Test update_kv_cache_config() with multiple field updates."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
interface.update_kv_cache_config(
|
||||
tokens_per_block=64,
|
||||
max_tokens=2048,
|
||||
free_gpu_memory_fraction=0.8,
|
||||
)
|
||||
|
||||
assert interface.kv_cache_config.tokens_per_block == 64
|
||||
assert interface.kv_cache_config.max_tokens == 2048
|
||||
assert interface.kv_cache_config.free_gpu_memory_fraction == 0.8
|
||||
|
||||
|
||||
def test_update_kv_cache_config_invalid_field_raises():
|
||||
"""Test update_kv_cache_config() raises ValueError for invalid fields."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid KVCacheConfig field"):
|
||||
interface.update_kv_cache_config(invalid_field=123)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# named_args and args Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_named_args_includes_sequence_info_and_caches(paged_kv_cache_config):
|
||||
"""Verify named_args includes both SequenceInfo args and caches."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.initialize_resources()
|
||||
|
||||
named_args = interface.named_args
|
||||
|
||||
# Should contain sequence info args
|
||||
assert "input_ids" in named_args
|
||||
assert "position_ids" in named_args
|
||||
|
||||
# Should contain cache
|
||||
assert "k_cache_0" in named_args
|
||||
|
||||
|
||||
def test_args_returns_tuple_of_tensors(paged_kv_cache_config):
|
||||
"""Verify args returns a tuple of tensors."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.initialize_resources()
|
||||
|
||||
args = interface.args
|
||||
|
||||
assert isinstance(args, tuple)
|
||||
assert all(isinstance(a, torch.Tensor) for a in args)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# to() method Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_to_moves_sequence_info(paged_kv_cache_config):
|
||||
"""Verify to() moves SequenceInfo to the target device."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cpu",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.to("cuda")
|
||||
|
||||
assert interface.info.device.type == "cuda"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SequenceInfo API Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_sequence_info_tokens_per_block_from_constructor():
|
||||
"""Verify tokens_per_block is set correctly from constructor."""
|
||||
seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, tokens_per_block=32)
|
||||
assert seq_info.tokens_per_block == 32
|
||||
|
||||
|
||||
def test_sequence_info_tokens_per_block_defaults_to_max_seq_len():
|
||||
"""Verify tokens_per_block defaults to max_seq_len when not provided."""
|
||||
seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4)
|
||||
assert seq_info.tokens_per_block == 128
|
||||
|
||||
|
||||
def test_sequence_info_estimate_cache_tokens_per_forward():
|
||||
"""Test estimate_cache_tokens_per_forward() calculation."""
|
||||
# With max_num_tokens=64, max_batch_size=4, tokens_per_block=16
|
||||
# seq_len = ceil(64/4) = 16
|
||||
# num_blocks_per_seq = ceil(16/16) = 1
|
||||
# num_blocks_total = 1 * 4 = 4
|
||||
# result = 4 * 16 = 64
|
||||
seq_info = SequenceInfo(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
tokens_per_block=16,
|
||||
max_num_tokens=64,
|
||||
)
|
||||
|
||||
result = seq_info.estimate_cache_tokens_per_forward()
|
||||
|
||||
# Expected: ceil(64/4) = 16 tokens per seq
|
||||
# ceil(16/16) = 1 block per seq
|
||||
# 1 * 4 = 4 blocks total
|
||||
# 4 * 16 = 64 tokens
|
||||
assert result == 64
|
||||
|
||||
|
||||
def test_sequence_info_estimate_cache_tokens_per_forward_with_overflow():
|
||||
"""Test estimate_cache_tokens_per_forward() with sequence overflow into extra blocks."""
|
||||
# With max_num_tokens=100, max_batch_size=4, tokens_per_block=16
|
||||
# seq_len = ceil(100/4) = 25
|
||||
# num_blocks_per_seq = ceil(25/16) = 2
|
||||
# num_blocks_total = 2 * 4 = 8
|
||||
# result = 8 * 16 = 128
|
||||
seq_info = SequenceInfo(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
tokens_per_block=16,
|
||||
max_num_tokens=100,
|
||||
)
|
||||
|
||||
result = seq_info.estimate_cache_tokens_per_forward()
|
||||
assert result == 128
|
||||
|
||||
|
||||
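# Worked sketch of the estimate the two tests above encode (a formula inferred
# from their expected values, not a reference implementation):
#   tokens_per_seq = ceil(max_num_tokens / max_batch_size)
#   blocks_per_seq = ceil(tokens_per_seq / tokens_per_block)
#   estimate       = blocks_per_seq * max_batch_size * tokens_per_block
def _estimate_cache_tokens_sketch(max_num_tokens, max_batch_size, tokens_per_block):
    tokens_per_seq = -(-max_num_tokens // max_batch_size)  # ceiling division
    blocks_per_seq = -(-tokens_per_seq // tokens_per_block)
    return blocks_per_seq * max_batch_size * tokens_per_block

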
def test_sequence_info_estimate_cache_loc_capacity_no_resize():
|
||||
"""Test estimate_cache_loc_capacity() when capacity is sufficient."""
|
||||
seq_info = SequenceInfo(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
tokens_per_block=32,
|
||||
max_num_tokens=256,
|
||||
)
|
||||
|
||||
initial_capacity = seq_info._input_buffer.get_capacity("cache_loc")
|
||||
|
||||
# Request a small capacity that should already be available
|
||||
seq_info.estimate_cache_loc_capacity(num_blocks=4)
|
||||
|
||||
# Capacity should not have changed if it was already sufficient
|
||||
if initial_capacity >= 4 * 4 + 1: # num_blocks * max_batch_size + 1
|
||||
assert seq_info._input_buffer.get_capacity("cache_loc") == initial_capacity
|
||||
|
||||
|
||||
def test_sequence_info_estimate_cache_loc_capacity_resizes():
|
||||
"""Test estimate_cache_loc_capacity() resizes buffer when needed."""
|
||||
seq_info = SequenceInfo(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
tokens_per_block=32,
|
||||
max_num_tokens=128, # Small to have small initial capacity
|
||||
)
|
||||
|
||||
initial_capacity = seq_info._input_buffer.get_capacity("cache_loc")
|
||||
|
||||
# Request a large capacity
|
||||
large_num_blocks = 1000
|
||||
seq_info.estimate_cache_loc_capacity(num_blocks=large_num_blocks)
|
||||
|
||||
expected_capacity = large_num_blocks * 4 + 1 # num_blocks * max_batch_size + 1
|
||||
if expected_capacity > initial_capacity:
|
||||
assert seq_info._input_buffer.get_capacity("cache_loc") >= expected_capacity
|
||||
|
||||
|
||||
def test_sequence_info_last_page_len_uses_tokens_per_block():
|
||||
"""Verify nest_sequences calculates last_page_len using tokens_per_block."""
|
||||
seq_info = SequenceInfo(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
tokens_per_block=16,
|
||||
)
|
||||
|
||||
# Set up a sequence with 25 tokens
|
||||
# last_page_len = (25 - 1) % 16 + 1 = 24 % 16 + 1 = 8 + 1 = 9
|
||||
input_ids = [[1] * 25]
|
||||
seq_info.nest_sequences(
|
||||
input_ids,
|
||||
input_pos=0,
|
||||
cache_loc=[0, 1], # 2 pages
|
||||
pages_per_seq=[2],
|
||||
)
|
||||
|
||||
expected_last_page_len = (25 - 1) % 16 + 1
|
||||
assert seq_info._args_list["last_page_len"][0] == expected_last_page_len
|
||||
|
||||
|
||||
def test_sequence_info_page_assignments():
|
||||
"""Test page_assignments property returns correct structure."""
|
||||
seq_info = SequenceInfo(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
tokens_per_block=16,
|
||||
)
|
||||
|
||||
# Set up two sequences with different page assignments
|
||||
input_ids = [[1] * 10, [1] * 20]
|
||||
seq_info.nest_sequences(
|
||||
input_ids,
|
||||
input_pos=0,
|
||||
cache_loc=[0, 1, 2], # seq 0 has page 0, seq 1 has pages 1 and 2
|
||||
pages_per_seq=[1, 2],
|
||||
)
|
||||
|
||||
page_assignments = seq_info.page_assignments
|
||||
assert page_assignments == [[0], [1, 2]]
|
||||
@ -107,6 +107,7 @@ def test_create_autodeploy_executor_with_guided_decoding(
|
||||
512 # placeholder to satisfy ADEngine.build_from_config
|
||||
)
|
||||
mock_engine.cache_seq_interface.info.vocab_size_padded = vocab_size_padded
|
||||
mock_engine.cache_seq_interface.max_num_state_slots = max_batch_size
|
||||
|
||||
# Mock the specific dependencies requested, plus minimal additional mocks to prevent errors
|
||||
with (
|
||||
|
||||
@ -5,10 +5,11 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
|
||||
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine
|
||||
from tensorrt_llm._torch.auto_deploy.shim.demollm import DemoEngine
|
||||
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
|
||||
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
|
||||
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
|
||||
|
||||
|
||||
class TransformerLikeModelwithFakeCachePool(nn.Module):
|
||||
@ -38,12 +39,13 @@ def get_inference_model(cache_seq_interface):
|
||||
|
||||
model = TransformerLikeModelwithFakeCachePool(vocab_size, embed_dim, hidden_dim)
|
||||
model.eval().to(device)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize("engine_cls", [ADEngine, DemoEngine])
|
||||
@pytest.mark.parametrize("attn_page_size", [0, 2, 0])
|
||||
def test_engine(engine_cls: Type[ADEngine], attn_page_size: int):
|
||||
@pytest.mark.parametrize("tokens_per_block", [0, 2, 0])
|
||||
def test_engine(engine_cls: Type[ADEngine], tokens_per_block: int):
|
||||
"""Test the SimpleEngine functionality."""
|
||||
|
||||
seed = 42 # Set random seed for model param init
|
||||
@ -55,21 +57,24 @@ def test_engine(engine_cls: Type[ADEngine], attn_page_size: int):
|
||||
max_seq_len = 64
|
||||
max_batch_size = 8
|
||||
|
||||
sequence_info = SequenceInfo(
|
||||
# Create KvCacheConfig with specified tokens_per_block (fall back to max_seq_len if 0)
|
||||
kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block or max_seq_len)
|
||||
cache_seq_interface = CachedSequenceInterface(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
page_size=attn_page_size,
|
||||
device=device,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
sequence_info.to(device)
|
||||
cache_seq_interface.to(device)
|
||||
|
||||
engine = engine_cls(get_inference_model, sequence_info, device)
|
||||
engine = engine_cls(get_inference_model, cache_seq_interface)
|
||||
|
||||
# Test basic token generation
|
||||
with torch.inference_mode():
|
||||
# Test logits
|
||||
input_ids = [torch.tensor([0, 1, 2], device=device)]
|
||||
sequence_info.reset()
|
||||
sequence_info.nest_sequences(input_ids)
|
||||
cache_seq_interface.info.reset()
|
||||
cache_seq_interface.info.nest_sequences(input_ids)
|
||||
logits = engine._compute_logits()
|
||||
assert logits is not None, "Logits are None"
|
||||
|
||||
@ -77,9 +82,11 @@ def test_engine(engine_cls: Type[ADEngine], attn_page_size: int):
|
||||
original_logits = get_inference_model(mock_input)(input_ids[0].unsqueeze(0))[0]
|
||||
assert torch.allclose(logits, original_logits, atol=1e-5), "Generated Token ID mismatch"
|
||||
|
||||
cache_seq_interface.shutdown()
|
||||
|
||||
@pytest.mark.parametrize("attn_page_size", [0, 2])
|
||||
def test_demo_engine_sampling(attn_page_size: int):
|
||||
|
||||
@pytest.mark.parametrize("tokens_per_block", [0, 2])
|
||||
def test_demo_engine_sampling(tokens_per_block: int):
|
||||
"""Test sampling logic specific to DemoEngine."""
|
||||
seed = 0
|
||||
torch.manual_seed(seed)
|
||||
@ -90,19 +97,22 @@ def test_demo_engine_sampling(attn_page_size: int):
|
||||
max_seq_len = 64
|
||||
max_batch_size = 8
|
||||
|
||||
sequence_info = SequenceInfo(
|
||||
# Create KvCacheConfig with specified tokens_per_block (use 32 as default if 0)
|
||||
kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block or 32)
|
||||
cache_seq_interface = CachedSequenceInterface(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
page_size=attn_page_size,
|
||||
device=device,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
sequence_info.to(device)
|
||||
cache_seq_interface.to(device)
|
||||
|
||||
engine = DemoEngine(get_inference_model, sequence_info, device)
|
||||
engine = DemoEngine(get_inference_model, cache_seq_interface)
|
||||
|
||||
with torch.inference_mode():
|
||||
input_ids = [torch.tensor([1, 2, 3, 4], device=device)]
|
||||
sequence_info.reset()
|
||||
sequence_info.nest_sequences(input_ids)
|
||||
cache_seq_interface.info.reset()
|
||||
cache_seq_interface.info.nest_sequences(input_ids)
|
||||
logits = engine._compute_logits()
|
||||
|
||||
vocab_size = logits.size(-1)
|
||||
@ -127,6 +137,8 @@ def test_demo_engine_sampling(attn_page_size: int):
|
||||
|
||||
torch.testing.assert_close(token_ids_1, token_ids_2)
|
||||
|
||||
cache_seq_interface.shutdown()
|
||||
|
||||
|
||||
class _DummyKVCacheManager:
|
||||
def __init__(self, tokens_per_block: int):
|
||||
@ -163,8 +175,8 @@ class _DummyRequest:
|
||||
return self._tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("attn_page_size", [256, 2])
|
||||
def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int):
|
||||
@pytest.mark.parametrize("tokens_per_block", [256, 2])
|
||||
def test_ad_engine_chunked_prefill_equivalence(tokens_per_block: int):
|
||||
"""Verify ADEngine logits match between chunked and non-chunked prefill.
|
||||
|
||||
We simulate chunking by splitting a single context request into two chunks and
|
||||
@ -179,19 +191,22 @@ def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int):
|
||||
max_seq_len = 64
|
||||
max_batch_size = 8
|
||||
|
||||
sequence_info = SequenceInfo(
|
||||
# Create KvCacheConfig with specified tokens_per_block
|
||||
kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block)
|
||||
cache_seq_interface = CachedSequenceInterface(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
page_size=attn_page_size,
|
||||
device=device,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
sequence_info.to(device)
|
||||
cache_seq_interface.to(device)
|
||||
|
||||
engine = ADEngine(get_inference_model, sequence_info, device)
|
||||
engine = ADEngine(get_inference_model, cache_seq_interface)
|
||||
|
||||
# A simple prompt; model is position-wise so last token dominates the last logit
|
||||
tokens = [1, 2, 3, 4, 5, 6]
|
||||
|
||||
kv_manager = _DummyKVCacheManager(tokens_per_block=attn_page_size)
|
||||
kv_manager = _DummyKVCacheManager(tokens_per_block=tokens_per_block)
|
||||
resource_manager = _DummyResourceManager(kv_manager)
|
||||
|
||||
# No-chunk: whole prompt in one request
|
||||
@ -215,3 +230,237 @@ def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int):
|
||||
logits_chunked_last = engine.forward(scheduled_requests_part2, resource_manager)["logits"][-1]
|
||||
|
||||
torch.testing.assert_close(logits_full_last, logits_chunked_last) # , atol=1e-5)
|
||||
|
||||
cache_seq_interface.shutdown()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Hybrid Cache Manager Integration Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class _DummyHybridKVCacheManager:
|
||||
"""Simulates MambaHybridCacheManager with mamba_cache_index."""
|
||||
|
||||
def __init__(self, tokens_per_block: int, num_slots: int = 8):
|
||||
self.tokens_per_block = tokens_per_block
|
||||
# mamba_cache_index maps request_id to slot_idx
|
||||
self.mamba_cache_index = {i: num_slots - 1 - i for i in range(num_slots)}
|
||||
self.mamba_cache_free_blocks = num_slots
|
||||
|
||||
def get_cache_indices(self, request):
|
||||
return list(range(1024))
|
||||
|
||||
def get_num_kv_blocks(self, num_tokens: int) -> int:
|
||||
if self.tokens_per_block and self.tokens_per_block > 0:
|
||||
return (num_tokens + self.tokens_per_block - 1) // self.tokens_per_block
|
||||
return num_tokens
|
||||
|
||||
def get_num_free_blocks(self):
|
||||
return 100
|
||||
|
||||
|
||||
class _DummyRequestWithRequestId:
|
||||
"""Request with py_request_id for hybrid cache manager testing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokens: List[int],
|
||||
begin: int,
|
||||
size: int,
|
||||
seq_slot: int = 0,
|
||||
request_id: int = 0,
|
||||
):
|
||||
self._tokens = tokens
|
||||
self.context_current_position = begin
|
||||
self.context_chunk_size = size
|
||||
self.seq_slot = seq_slot
|
||||
self.py_request_id = request_id
|
||||
self.py_batch_idx = None
|
||||
self.py_multimodal_data = None
|
||||
|
||||
def get_tokens(self, _beam: int) -> List[int]:
|
||||
return self._tokens
|
||||
|
||||
|
||||
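# Hedged sketch of the slot-resolution rule the next tests verify (an illustration,
# not the engine's actual implementation): prefer the hybrid cache manager's
# mamba_cache_index mapping when it exists, otherwise fall back to seq_slot.
def _resolve_slot_idx_sketch(kv_manager, request):
    mamba_index = getattr(kv_manager, "mamba_cache_index", None)
    if mamba_index is not None:
        return mamba_index[request.py_request_id]
    return request.seq_slot

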
def test_ad_engine_prepare_inputs_with_hybrid_cache_manager():
|
||||
"""Test ADEngine _prepare_inputs uses mamba_cache_index when available."""
|
||||
seed = 42
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
device = torch.device("cuda")
|
||||
max_seq_len = 64
|
||||
max_batch_size = 8
|
||||
tokens_per_block = 16
|
||||
|
||||
kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block)
|
||||
cache_seq_interface = CachedSequenceInterface(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
device=device,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
cache_seq_interface.to(device)
|
||||
|
||||
engine = ADEngine(get_inference_model, cache_seq_interface)
|
||||
|
||||
# Create hybrid KV cache manager with specific mamba_cache_index mapping
|
||||
hybrid_manager = _DummyHybridKVCacheManager(tokens_per_block=tokens_per_block)
|
||||
|
||||
class _HybridResourceManager:
|
||||
def __init__(self, kv_mgr):
|
||||
self._kv = kv_mgr
|
||||
|
||||
def get_resource_manager(self, _):
|
||||
return self._kv
|
||||
|
||||
resource_manager = _HybridResourceManager(hybrid_manager)
|
||||
|
||||
# Create request with specific request_id
|
||||
request_id = 3
|
||||
tokens = [1, 2, 3, 4]
|
||||
req = _DummyRequestWithRequestId(
|
||||
tokens=tokens,
|
||||
begin=0,
|
||||
size=len(tokens),
|
||||
seq_slot=0,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
scheduled = ScheduledRequests()
|
||||
scheduled.context_requests.append(req)
|
||||
|
||||
# Call _prepare_inputs
|
||||
engine._prepare_inputs(scheduled, resource_manager, new_tokens=None)
|
||||
|
||||
# Verify slot_idx was taken from mamba_cache_index, not seq_slot
|
||||
expected_slot_idx = hybrid_manager.mamba_cache_index[request_id]
|
||||
actual_slot_idx = cache_seq_interface.info._args_list["slot_idx"][0]
|
||||
assert actual_slot_idx == expected_slot_idx
|
||||
|
||||
cache_seq_interface.shutdown()
|
||||
|
||||
|
||||
def test_ad_engine_prepare_inputs_generation_with_hybrid_cache():
|
||||
"""Test ADEngine _prepare_inputs handles generation requests with hybrid cache."""
|
||||
seed = 42
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
device = torch.device("cuda")
|
||||
max_seq_len = 64
|
||||
max_batch_size = 8
|
||||
tokens_per_block = 16
|
||||
|
||||
kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block)
|
||||
cache_seq_interface = CachedSequenceInterface(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
device=device,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
cache_seq_interface.to(device)
|
||||
|
||||
engine = ADEngine(get_inference_model, cache_seq_interface)
|
||||
|
||||
# Create hybrid KV cache manager
|
||||
hybrid_manager = _DummyHybridKVCacheManager(tokens_per_block=tokens_per_block)
|
||||
|
||||
class _HybridResourceManager:
|
||||
def __init__(self, kv_mgr):
|
||||
self._kv = kv_mgr
|
||||
|
||||
def get_resource_manager(self, _):
|
||||
return self._kv
|
||||
|
||||
resource_manager = _HybridResourceManager(hybrid_manager)
|
||||
|
||||
# Create generation request
|
||||
class _GenRequest:
|
||||
def __init__(self, request_id: int, seq_slot: int, num_tokens: int):
|
||||
self.py_request_id = request_id
|
||||
self.seq_slot = seq_slot
|
||||
self.py_batch_idx = None
|
||||
self.is_dummy = False
|
||||
self.py_draft_tokens = []
|
||||
|
||||
# Mock methods for generation request
|
||||
def get_token(beam, idx):
|
||||
return 42 # Dummy token
|
||||
|
||||
self.get_token = get_token
|
||||
self.get_num_tokens = lambda beam: num_tokens
|
||||
self.max_beam_num_tokens = num_tokens
|
||||
|
||||
def get_draft_token_length(self):
|
||||
return 0
|
||||
|
||||
# Create a generation request with specific request_id
|
||||
request_id = 2
|
||||
gen_req = _GenRequest(request_id=request_id, seq_slot=5, num_tokens=10)
|
||||
|
||||
scheduled = ScheduledRequests()
|
||||
scheduled.generation_requests.append(gen_req)
|
||||
|
||||
# Call _prepare_inputs
|
||||
engine._prepare_inputs(scheduled, resource_manager, new_tokens=None)
|
||||
|
||||
# Verify slot_idx was taken from mamba_cache_index
|
||||
expected_slot_idx = hybrid_manager.mamba_cache_index[request_id]
|
||||
actual_slot_idx = cache_seq_interface.info._args_list["slot_idx"][0]
|
||||
assert actual_slot_idx == expected_slot_idx
|
||||
|
||||
cache_seq_interface.shutdown()
|
||||
|
||||
|
||||
def test_ad_engine_with_regular_kv_cache_manager():
|
||||
"""Test ADEngine falls back to seq_slot when mamba_cache_index not available."""
|
||||
seed = 42
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
device = torch.device("cuda")
|
||||
max_seq_len = 64
|
||||
max_batch_size = 8
|
||||
tokens_per_block = 16
|
||||
|
||||
kv_cache_config = KvCacheConfig(tokens_per_block=tokens_per_block)
|
||||
cache_seq_interface = CachedSequenceInterface(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
device=device,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
cache_seq_interface.to(device)
|
||||
|
||||
engine = ADEngine(get_inference_model, cache_seq_interface)
|
||||
|
||||
# Use regular (non-hybrid) KV cache manager without mamba_cache_index
|
||||
regular_manager = _DummyKVCacheManager(tokens_per_block=tokens_per_block)
|
||||
resource_manager = _DummyResourceManager(regular_manager)
|
||||
|
||||
# Create request with specific seq_slot
|
||||
expected_seq_slot = 3
|
||||
tokens = [1, 2, 3, 4]
|
||||
req = _DummyRequest(
|
||||
tokens=tokens,
|
||||
begin=0,
|
||||
size=len(tokens),
|
||||
seq_slot=expected_seq_slot,
|
||||
)
|
||||
|
||||
scheduled = ScheduledRequests()
|
||||
scheduled.context_requests.append(req)
|
||||
|
||||
# Call _prepare_inputs
|
||||
engine._prepare_inputs(scheduled, resource_manager, new_tokens=None)
|
||||
|
||||
# Verify slot_idx falls back to seq_slot
|
||||
actual_slot_idx = cache_seq_interface.info._args_list["slot_idx"][0]
|
||||
assert actual_slot_idx == expected_seq_slot
|
||||
|
||||
cache_seq_interface.shutdown()
|
||||
|
||||
@ -4,7 +4,6 @@ import pydantic
|
||||
import pytest
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy import LLM, DemoLLM, LlmArgs
|
||||
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
|
||||
|
||||
|
||||
def test_custom_values():
|
||||
@ -14,7 +13,6 @@ def test_custom_values():
|
||||
"model_factory": "AutoModelForImageTextToText",
|
||||
"model_kwargs": {"custom_param": True},
|
||||
"skip_loading_weights": True,
|
||||
"attn_page_size": 128,
|
||||
"max_seq_len": 2048,
|
||||
"transforms": {
|
||||
"detect_sharding": {
|
||||
@ -25,10 +23,6 @@ def test_custom_values():
|
||||
"stage": "cache_init",
|
||||
"backend": "flashinfer",
|
||||
},
|
||||
"resize_kv_cache": {
|
||||
"stage": "cache_init",
|
||||
"free_mem_ratio": 0.9,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -39,32 +33,12 @@ def test_custom_values():
|
||||
"custom_param": True,
|
||||
}
|
||||
assert args.skip_loading_weights
|
||||
assert args.transforms["resize_kv_cache"]["free_mem_ratio"] == 0.9
|
||||
assert args.transforms["detect_sharding"]["simple_shard_only"]
|
||||
assert args.attn_page_size == 128
|
||||
assert args.max_seq_len == 2048
|
||||
# backend should be overridden if it was 'TRTLLM'
|
||||
assert args.transforms["insert_cached_attention"]["backend"] == "flashinfer"
|
||||
|
||||
|
||||
def test_free_mem_ratio_validation():
|
||||
"""Test free_mem_ratio validation."""
|
||||
|
||||
def get_transform_config(free_mem_ratio):
|
||||
return {"resize_kv_cache": {"stage": "cache_init", "free_mem_ratio": free_mem_ratio}}
|
||||
|
||||
# Valid values
|
||||
InferenceOptimizer(None, get_transform_config(0.0))
|
||||
InferenceOptimizer(None, get_transform_config(1.0))
|
||||
InferenceOptimizer(None, get_transform_config(0.5))
|
||||
|
||||
# Invalid values
|
||||
with pytest.raises(ValueError):
|
||||
InferenceOptimizer(None, get_transform_config(-0.1))
|
||||
with pytest.raises(ValueError):
|
||||
InferenceOptimizer(None, get_transform_config(1.1))
|
||||
|
||||
|
||||
# ================================
|
||||
# Config Flow Tests
|
||||
# ================================
|
||||
@ -77,7 +51,6 @@ def test_config_params():
|
||||
"model": "test-model",
|
||||
"model_factory": "AutoModelForImageTextToText",
|
||||
"skip_loading_weights": True,
|
||||
"attn_page_size": 17,
|
||||
"max_seq_len": 19,
|
||||
"max_batch_size": 5,
|
||||
"world_size": 3,
|
||||
@ -90,10 +63,6 @@ def test_config_params():
|
||||
"stage": "cache_init",
|
||||
"backend": "flashinfer",
|
||||
},
|
||||
"resize_kv_cache": {
|
||||
"stage": "cache_init",
|
||||
"free_mem_ratio": 0.7,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -151,16 +120,11 @@ def test_config_flow(
|
||||
|
||||
# Common assertions for both APIs
|
||||
assert instance.args.model_factory == test_config_params["model_factory"]
|
||||
assert (
|
||||
instance.args.transforms["resize_kv_cache"]["free_mem_ratio"]
|
||||
== test_config_params["transforms"]["resize_kv_cache"]["free_mem_ratio"]
|
||||
)
|
||||
assert (
|
||||
instance.args.transforms["detect_sharding"]["simple_shard_only"]
|
||||
== test_config_params["transforms"]["detect_sharding"]["simple_shard_only"]
|
||||
)
|
||||
assert instance.args.skip_loading_weights == test_config_params["skip_loading_weights"]
|
||||
assert instance.args.attn_page_size == test_config_params["attn_page_size"]
|
||||
assert instance.args.max_seq_len == test_config_params["max_seq_len"]
|
||||
assert instance.args.max_batch_size == test_config_params["max_batch_size"]
|
||||
|
||||
@ -215,23 +179,6 @@ def test_parallel_config_validation(parallel_field, invalid_value):
|
||||
LlmArgs(**kwargs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"backend,expected_attn_page_size",
|
||||
[
|
||||
("flashinfer", 64), # Default attn_page_size
|
||||
("triton", 1024), # Should equal max_seq_len
|
||||
],
|
||||
)
|
||||
def test_attention_backend_page_size_logic(backend, expected_attn_page_size):
|
||||
"""Test attn_page_size logic for different attention backends."""
|
||||
args = LlmArgs(
|
||||
model="test-model",
|
||||
max_seq_len=1024,
|
||||
transforms={"insert_cached_attention": {"stage": "cache_init", "backend": backend}},
|
||||
)
|
||||
assert args.attn_page_size == expected_attn_page_size
|
||||
|
||||
|
||||
# ================================
|
||||
# CUDA Graph Batch Sizes Tests
|
||||
# ================================
|
||||
@ -351,7 +298,7 @@ class TestSequenceInfoExampleBatchSize:
|
||||
max_batch_size=1,
|
||||
max_seq_len=128,
|
||||
max_num_tokens=128,
|
||||
page_size=64,
|
||||
tokens_per_block=64,
|
||||
)
|
||||
|
||||
# Set example sequence (this is what's used during export)
|
||||
@ -371,7 +318,7 @@ class TestSequenceInfoExampleBatchSize:
|
||||
max_batch_size=32,
|
||||
max_seq_len=128,
|
||||
max_num_tokens=128,
|
||||
page_size=64,
|
||||
tokens_per_block=64,
|
||||
)
|
||||
|
||||
seq_info.set_example_sequence()
|
||||
|
||||
@ -41,8 +41,10 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
|
||||
(
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
{
|
||||
"kv_cache_config": {
|
||||
"free_gpu_memory_fraction": 0.0001,
|
||||
},
|
||||
"transforms": {
|
||||
"resize_kv_cache": {"free_mem_ratio": 0.0001},
|
||||
"insert_cached_attention": {"backend": "flashinfer"},
|
||||
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9878
|
||||
# "compile_model": {"backend": "torch-opt"},
|
||||
|
||||
@ -84,25 +84,27 @@ def prepare_dataset(root_dir: str, temp_dir: str, model_path_or_name: str):
|
||||
@pytest.mark.parametrize("compile_backend", ["torch-compile", "torch-opt", "torch-cudagraph"])
|
||||
@pytest.mark.parametrize("model_name", ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"])
|
||||
def test_trtllm_bench(llm_root, compile_backend, model_name): # noqa: F811
|
||||
config = get_small_model_config(model_name)
|
||||
args = get_small_model_config(model_name)["args"]
|
||||
# remove kv_cache_config and max_batch_size to avoid conflicts with trtllm-bench
|
||||
args.pop("kv_cache_config", None)
|
||||
args.pop("max_batch_size", None)
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
extra_llm_api_options_path = f"{temp_dir}/extra_llm_api_options.yaml"
|
||||
with open(extra_llm_api_options_path, "w") as f:
|
||||
yaml.dump(
|
||||
{
|
||||
**config["args"],
|
||||
**args,
|
||||
"transforms": {
|
||||
"resize_kv_cache": {"enabled": False}, # rely on default estimation
|
||||
"compile_model": {
|
||||
"stage": "compile",
|
||||
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
|
||||
"backend": compile_backend,
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
dataset_path = prepare_dataset(llm_root, temp_dir, config["args"]["model"])
|
||||
run_benchmark(
|
||||
model_name, str(config["args"]["model"]), dataset_path, extra_llm_api_options_path
|
||||
)
|
||||
dataset_path = prepare_dataset(llm_root, temp_dir, args["model"])
|
||||
run_benchmark(model_name, str(args["model"]), dataset_path, extra_llm_api_options_path)
|
||||
|
||||
@ -22,7 +22,6 @@ from torch.export import Dim
|
||||
from torch.fx import GraphModule
|
||||
|
||||
# Import to register custom op
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
|
||||
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
|
||||
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
|
||||
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
|
||||
@ -154,13 +153,12 @@ class TestGatherLogitsBeforeLmHeadTransform:
|
||||
|
||||
def _create_cached_sequence_interface(self, max_batch_size: int = 8, device: str = "cuda"):
|
||||
"""Create a mock CachedSequenceInterface for testing."""
|
||||
seq_info = SequenceInfo(
|
||||
return CachedSequenceInterface(
|
||||
max_seq_len=64,
|
||||
max_batch_size=max_batch_size,
|
||||
device=device,
|
||||
max_num_tokens=1024,
|
||||
)
|
||||
seq_info.to(device)
|
||||
return CachedSequenceInterface(seq_info, device=device)
|
||||
|
||||
def _check_gather_op_in_graph(self, gm: GraphModule) -> bool:
|
||||
"""Check if gather_logits_before_lm_head op is in the graph."""
|
||||
|
||||
@ -1,4 +1,5 @@
from typing import List, Optional
from unittest.mock import MagicMock

import pytest
import torch
@ -6,7 +7,8 @@ import torch.nn as nn
from _model_test_utils import GQA
from _torch_test_utils import all_close

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig, SequenceInfo
# Initialize resources first
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import PagedResourceHandler
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.models.factory import (
FullModelExportInfo,
@ -14,15 +16,18 @@ from tensorrt_llm._torch.auto_deploy.models.factory import (
SubModuleExportInfo,
)
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm._torch.auto_deploy.transform.interface import Stages, TransformConfig
from tensorrt_llm._torch.auto_deploy.transform.library.kvcache import InitializeCache, ResizeKVCache
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm.llmapi.llm_args import KvCacheConfig


class DummyFactory(ModelFactory):
"""Dummy factory to pass cache_config for testing."""
"""Dummy factory to pass cache_config_updates for testing."""

def __init__(self, model, cache_config):
def __init__(self, model, cache_config_updates):
self._model = model
self.cache_config = cache_config
self.cache_config_updates = cache_config_updates

def build_model(self, device: str):
return self._model.to(device=device)
@ -33,8 +38,8 @@ class DummyFactory(ModelFactory):
def _load_checkpoint(self, model, device):
return

def get_cache_config(self):
return self.cache_config
def get_cache_config_updates(self):
return self.cache_config_updates

def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]:
return [FullModelExportInfo()]
@ -151,12 +156,19 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
max_position_embeddings = 128
vocab_size = 1000

# set up sequence+cache objects using standard SequenceInfo
ci = SequenceInfo(
# set up sequence+cache objects using CachedSequenceInterface
# Use tokens_per_block=max_position_embeddings so each sequence fits in 1 page for the test
kv_cache_config = KvCacheConfig(
tokens_per_block=max_position_embeddings,
max_tokens=batch_size * max_position_embeddings,
free_gpu_memory_fraction=0.0, # Disable dynamic resizing for test
)
cm = CachedSequenceInterface(
max_seq_len=max_position_embeddings,
max_batch_size=batch_size,
device="cuda",
kv_cache_config=kv_cache_config,
)
cm = CachedSequenceInterface(sequence_info=ci, device="cuda")

# Create the model with embedding layer and SDPA, wrap it in a fake factory
model = GQAWithSdpaAndEmbedding(
@ -175,7 +187,7 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):

# Apply the transformation
optimizer = InferenceOptimizer(
DummyFactory(model, CacheConfig()),
DummyFactory(model, cache_config_updates={}),
{
"build_model": {
"stage": "factory",
@ -204,7 +216,7 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
gm = optimizer(cm)

gm.to("cuda")
num_caches = cm.initialize_caches()
num_caches = cm.initialize_resources()
print(f"num_caches: {num_caches}")

# Helper function to call the model with proper sequence nesting
@ -248,3 +260,273 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
# Test 4: Exportability of the transformed model
exported_gm = torch_export_to_gm(gm, args=(), kwargs=cm.named_args)
assert exported_gm is not None


# =============================================================================
# Transform Unit Tests for Refactored Pipeline
# =============================================================================


@pytest.fixture
def dummy_cached_interface():
"""Create a CachedSequenceInterface for transform testing."""
kv_cache_config = KvCacheConfig(
tokens_per_block=32,
max_tokens=256,
free_gpu_memory_fraction=0.0,
)
return CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=kv_cache_config,
)


def test_initialize_cache_transform_calls_initialize_resources(dummy_cached_interface):
"""Verify InitializeCache transform calls cm.initialize_resources()."""
# Create a mock module
mock_module = MagicMock()

# Create the transform with a proper config
transform = InitializeCache(config=TransformConfig(stage=Stages.PATTERN_MATCHER))

# Add a resource to verify initialize_resources is called
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import PagedResourceHandler

dummy_cached_interface.add_resource(
"k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)
)

# Mock the factory and shared_config
mock_factory = MagicMock()
mock_shared_config = MagicMock()

# Run the transform
result_mod, info = transform._apply_to_full_model(
mock_module, dummy_cached_interface, mock_factory, mock_shared_config
)

# Verify caches were initialized
assert info.skipped is False
assert info.num_matches >= 1
assert dummy_cached_interface.kv_cache_manager is not None


def test_resize_kv_cache_transform_skipped_when_not_needed(dummy_cached_interface):
"""Verify ResizeKVCache transform is skipped when resize not needed."""
dummy_cached_interface.add_resource(
"k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)
)
dummy_cached_interface.initialize_resources()

# Create the transform with a proper config
transform = ResizeKVCache(config=TransformConfig(stage=Stages.PATTERN_MATCHER))

# Create a mock module
mock_module = MagicMock()

# Mock forward call
mock_module.side_effect = lambda **kwargs: None

mock_factory = MagicMock()
mock_shared_config = MagicMock()

# Run the transform - should be skipped since free_gpu_memory_fraction=0.0
result_mod, info = transform._apply_to_full_model(
mock_module, dummy_cached_interface, mock_factory, mock_shared_config
)

# Verify transform was skipped
assert info.skipped is True


def test_resize_kv_cache_transform_runs_when_needed():
"""Verify ResizeKVCache transform runs when resize is needed."""
# Create interface with resizing enabled
kv_cache_config = KvCacheConfig(
tokens_per_block=32,
max_tokens=256,
free_gpu_memory_fraction=0.5, # Enable resizing
)
cm = CachedSequenceInterface(
max_seq_len=128,
max_batch_size=4,
device="cuda",
kv_cache_config=kv_cache_config,
)

cm.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
cm.initialize_resources()

# Create the transform with a proper config
transform = ResizeKVCache(config=TransformConfig(stage=Stages.PATTERN_MATCHER))

# Create a simple mock module that just returns None
class MockModule:
def __call__(self, **kwargs):
return None

mock_module = MockModule()
mock_factory = MagicMock()
mock_shared_config = MagicMock()

# Run the transform
result_mod, info = transform._apply_to_full_model(
mock_module, cm, mock_factory, mock_shared_config
)

# Verify transform was not skipped
assert info.skipped is False


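The two resize tests above share one small harness pattern; a condensed sketch follows. All names are taken from the imports in this diff, and the mock objects stand in for arguments these transforms do not inspect.

```python
# Condensed sketch of the ResizeKVCache test harness used above; requires a CUDA
# device and the refactored auto_deploy interfaces introduced in this change.
from unittest.mock import MagicMock

import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import PagedResourceHandler
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm._torch.auto_deploy.transform.interface import Stages, TransformConfig
from tensorrt_llm._torch.auto_deploy.transform.library.kvcache import ResizeKVCache
from tensorrt_llm.llmapi.llm_args import KvCacheConfig


class _NoopModule:
    """Stand-in for the graph module, mirroring the MockModule used in the test."""

    def __call__(self, **kwargs):
        return None


cm = CachedSequenceInterface(
    max_seq_len=128,
    max_batch_size=4,
    device="cuda",
    kv_cache_config=KvCacheConfig(
        tokens_per_block=32,
        max_tokens=256,
        free_gpu_memory_fraction=0.5,  # > 0 requests a resize; 0.0 keeps the static allocation
    ),
)
cm.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
cm.initialize_resources()

transform = ResizeKVCache(config=TransformConfig(stage=Stages.PATTERN_MATCHER))
_, info = transform._apply_to_full_model(_NoopModule(), cm, MagicMock(), MagicMock())
print(info.skipped)  # False here; True when free_gpu_memory_fraction == 0.0
```
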
def test_insert_cached_attention_uses_add_resource():
"""Verify InsertCachedAttention uses cm.add_resource() for cache registration."""
# This test verifies the integration point between InsertCachedAttention
# and CachedSequenceInterface.add_resource() by checking that after the
# transform, resources are registered in the interface.

num_attention_heads = 8
hidden_size = 512
num_key_value_heads = 8
vocab_size = 1000
batch_size = 4
max_seq_len = 64

kv_cache_config = KvCacheConfig(
tokens_per_block=max_seq_len,
max_tokens=batch_size * max_seq_len,
free_gpu_memory_fraction=0.0,
)
cm = CachedSequenceInterface(
max_seq_len=max_seq_len,
max_batch_size=batch_size,
device="cuda",
kv_cache_config=kv_cache_config,
)

# Create a model
model = GQAWithSdpaAndEmbedding(
num_attention_heads,
hidden_size,
num_key_value_heads,
vocab_size=vocab_size,
).to(dtype=torch.float16, device="cuda")

# Apply transformation
optimizer = InferenceOptimizer(
DummyFactory(model, cache_config_updates={}),
{
"build_model": {
"stage": "factory",
"run_per_gm": False,
"device": "cuda",
"run_graph_cleanup": False,
"requires_clean_graph": False,
},
"export_to_gm": {
"stage": "export",
"strict": False,
"run_per_gm": False,
"clone_state_dict": True,
"run_graph_cleanup": False,
"requires_clean_graph": False,
},
"cleanup_input_constraints": {
"stage": "post_export",
},
"insert_cached_attention": {
"stage": "cache_init",
"backend": "triton",
},
},
)

optimizer(cm)

# Verify resources were added
assert len(cm._resource_lookup) > 0
# Should have k_cache and v_cache resources registered
resource_names = list(cm._resource_lookup.keys())
assert any("k_cache" in name for name in resource_names)
assert any("v_cache" in name for name in resource_names)


def test_insert_cached_attention_passes_kv_cache_config():
"""Verify InsertCachedAttention passes cm.kv_cache_config to get_cache_initializers."""
# This test verifies that the KvCacheConfig from the interface is used
# when initializing cache resources (e.g., for dtype configuration).

num_attention_heads = 8
hidden_size = 512
num_key_value_heads = 8
vocab_size = 1000
batch_size = 4
max_seq_len = 64

# Use specific dtype in kv_cache_config
kv_cache_config = KvCacheConfig(
tokens_per_block=max_seq_len,
max_tokens=batch_size * max_seq_len,
free_gpu_memory_fraction=0.0,
dtype="bfloat16", # Specify explicit dtype
)
cm = CachedSequenceInterface(
max_seq_len=max_seq_len,
max_batch_size=batch_size,
device="cuda",
kv_cache_config=kv_cache_config,
)

# Verify kv_cache_config is accessible
assert cm.kv_cache_config.dtype == "bfloat16"

# Create a model
model = GQAWithSdpaAndEmbedding(
num_attention_heads,
hidden_size,
num_key_value_heads,
vocab_size=vocab_size,
).to(dtype=torch.bfloat16, device="cuda")

# Apply transformation
optimizer = InferenceOptimizer(
DummyFactory(model, cache_config_updates={}),
{
"build_model": {
"stage": "factory",
"run_per_gm": False,
"device": "cuda",
"run_graph_cleanup": False,
"requires_clean_graph": False,
},
"export_to_gm": {
"stage": "export",
"strict": False,
"run_per_gm": False,
"clone_state_dict": True,
"run_graph_cleanup": False,
"requires_clean_graph": False,
},
"cleanup_input_constraints": {
"stage": "post_export",
},
"insert_cached_attention": {
"stage": "cache_init",
"backend": "triton",
},
},
)

optimizer(cm)

# Initialize resources
cm.initialize_resources()

assert not cm.is_paged(), "triton should not use paged resources"
assert cm._caches, "at least some resources should be present"

# Verify cache dtype matches config
for name, handler in cm._resource_lookup.items():
if hasattr(handler, "dtype"):
assert handler.dtype == torch.bfloat16

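Taken together, the new tests pin down the refactored resource flow on `CachedSequenceInterface`. A minimal end-to-end sketch of that flow is given below, using only constructor and method names that appear in this diff; the sizes and the handler arguments are illustrative.

```python
# Minimal sketch of the refactored KV-cache resource flow exercised by the tests
# above; constructor/method names come from this diff, values are illustrative.
import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import PagedResourceHandler
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm.llmapi.llm_args import KvCacheConfig

# 1. Describe the cache budget and layout up front.
kv_cache_config = KvCacheConfig(
    tokens_per_block=64,
    max_tokens=256,
    free_gpu_memory_fraction=0.0,  # 0.0 keeps the static allocation; > 0 lets resize_kv_cache grow it
)

# 2. The interface now owns sequence limits, device placement, and the cache config.
cm = CachedSequenceInterface(
    max_seq_len=64,
    max_batch_size=4,
    device="cuda",
    kv_cache_config=kv_cache_config,
)

# 3. Cache-inserting transforms (insert_cached_attention in the tests) register
#    per-cache handlers; a single handler is registered directly here.
cm.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))

# 4. Materialize everything that was registered.
num_caches = cm.initialize_resources()
print(f"num_caches: {num_caches}")
```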