Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
This commit is contained in: parent e2bd9cce1e, commit 925d911fc0
@@ -270,5 +270,5 @@ Deprecation is used to inform developers that some APIs and tools are no longer

## Useful Links

- [Quantized models on Hugging Face](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4): A growing collection of quantized (e.g., FP8, FP4) and optimized LLMs, including [DeepSeek FP4](https://huggingface.co/nvidia/DeepSeek-R1-FP4), ready for fast inference with TensorRT LLM.
- [NVIDIA Dynamo](https://github.com/ai-dynamo/dynamo): A datacenter-scale distributed inference serving framework that works seamlessly with TensorRT LLM.
- [AutoDeploy](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html): A prototype backend for TensorRT LLM to simplify and accelerate the deployment of PyTorch models.
- [AutoDeploy](https://nvidia.github.io/TensorRT-LLM/features/auto_deploy/auto-deploy.html): A beta backend for TensorRT LLM to simplify and accelerate the deployment of PyTorch models.
- [WeChat Discussion Group](https://github.com/NVIDIA/TensorRT-LLM/issues/5359): A real-time channel for TensorRT LLM Q&A and news.
@@ -8,3 +8,4 @@ pygit2
sphinx_copybutton
autodoc_pydantic
sphinx-togglebutton
sphinxcontrib-mermaid

@@ -66,6 +66,7 @@ extensions = [
    'sphinx_copybutton',
    'sphinxcontrib.autodoc_pydantic',
    'sphinx_togglebutton',
    'sphinxcontrib.mermaid',
    'trtllm_config_selector',
]
@@ -0,0 +1,199 @@
# KV Cache Architecture

## Overview

The caching system in AutoDeploy manages KV caches for attention layers, SSM/convolution states for Mamba models, and other stateful resources. The architecture is built around three key concepts:

1. **Resource Handlers** - Abstract descriptions of cache resources (shape, dtype, layout)
1. **CachedSequenceInterface** - The central manager that collects handlers and allocates caches
1. **KVCacheManager / MambaHybridCacheManager** - Low-level memory managers from the executor
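
As a rough mental model (the real classes are described under "Key Files" below; the sketch uses hypothetical names and fields), a resource handler is a small value object that records what a cache tensor should look like, while the actual allocation is deferred until the cache manager knows its memory budget:

```python
from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class PagedKVDescriptor:
    """Hypothetical stand-in for a KV resource handler: metadata only, no memory yet."""

    num_kv_heads: int
    head_dim: int
    dtype: torch.dtype
    kv_factor: int = 2       # K and V stored in one buffer
    kv_layout: str = "HND"   # layout tag understood by the attention backend

    def allocate(self, num_pages: int, page_size: int) -> torch.Tensor:
        # Deferred allocation: only called once the paging budget is known.
        return torch.empty(
            num_pages, self.kv_factor, self.num_kv_heads, page_size, self.head_dim,
            dtype=self.dtype,
        )
```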

## Flowchart: Cache Collection and Allocation Pipeline

```{mermaid}
flowchart TD
    subgraph Phase1["PHASE 1: RESOURCE HANDLER COLLECTION"]
        A1[AttentionDescriptor implementations]
        A2["get_cache_initializers(node, config)"]
        A3["cm.add_resource(k_indexed, handler)"]
        A1 --> A2
        A2 -->|"Returns ResourceHandlerDict"| A3
    end

    subgraph Phase2["PHASE 2: INITIALIZE CACHES"]
        B1["_resource_lookup: Dict"]
        B2["Initialize _caches with None values"]
        B3["Order matches _resource_lookup"]
        B1 --> B2
        B2 --> B3
    end

    subgraph Phase3["PHASE 3: COMPATIBILITY CHECKING"]
        C1["_identify_managed_kv_resources()"]
        C2["_identify_managed_state_resources()"]
    end

    subgraph Phase4["PHASE 4: CACHE MANAGER CREATION"]
        D1{"Has state resources?"}
        D2["Create MambaHybridCacheManager"]
        D3["Create KVCacheManager"]
        D4["VIEW ASSIGNMENT"]
        D5["self._caches = tensor_view"]
        D1 -->|YES| D2
        D1 -->|NO| D3
        D2 --> D4
        D3 --> D4
        D4 --> D5
    end

    subgraph Phase5["PHASE 5: UNMANAGED RESOURCE ALLOCATION"]
        E1["handler.allocate(sequence_info)"]
    end

    subgraph Phase6["PHASE 6: OPTIONAL RESIZE"]
        F1["Recreate KVCacheManager with optimal capacity"]
    end

    Phase1 --> Phase2
    Phase2 --> Phase3
    Phase3 --> Phase4
    Phase4 --> Phase5
    Phase5 --> Phase6
```

## Detailed Pipeline Flow

For reference, here's the detailed text-based flow:

```text
PHASE 1: RESOURCE HANDLER COLLECTION (insert_cached_attention transform)
├── AttentionDescriptor implementations (e.g., FlashinferCachedAttention)
├── get_cache_initializers(node, config)
│   ├── Extracts shapes from FakeTensors
│   └── Returns ResourceHandlerDict {"kv_cache": KVPagedResourceHandler, ...}
└── For each attention node (idx=0,1,2...):
    └── cm.add_resource(f"{k}_{idx}", handler) → Stores in CachedSequenceInterface

PHASE 2: INITIALIZE CACHES (initialize_resources)
├── _resource_lookup contains all collected handlers
└── Initialize _caches dict with None values (same order as _resource_lookup)

PHASE 3: COMPATIBILITY CHECKING (dynamically from _resource_lookup)
├── _identify_managed_kv_resources()
│   ├── Iterate _resource_lookup, find first KVPagedResourceHandler → kv_ref
│   ├── All KVPagedResourceHandlers matching kv_ref (head_dim, dtype, layout) → kv_managed
│   └── Non-matching handlers → local allocation later
└── _identify_managed_state_resources()
    ├── Iterate _resource_lookup, find first SSMResourceHandler → ssm_ref
    ├── Iterate _resource_lookup, find first CausalConvResourceHandler → conv_ref
    ├── Check n_groups constraint: conv_dim = head_dim*num_heads + 2*n_groups*d_state
    └── If constraint fails → conv_ref = None (local allocation)

PHASE 4: CACHE MANAGER CREATION (_create_kv_cache_manager)
├── Has state resources? (ssm_managed or conv_managed)
│   ├── YES → Create MambaHybridCacheManager (manages KV + SSM + Conv)
│   └── NO → Create KVCacheManager (manages paged KV only)
└── View Assignment:
    ├── _assign_kv_cache_views() → manager.get_buffers(idx)
    └── _create_and_assign_state_views() → manager.get_ssm_states/get_conv_states

PHASE 5: UNMANAGED RESOURCE ALLOCATION
└── For resources where self._caches[name] is None:
    ├── self._caches[name] = handler.allocate(sequence_info)
    └── Track in _unmanaged_resources list (for proper .to() handling)

PHASE 6: OPTIONAL RESIZE (resize_kv_cache transform)
├── Run forward pass to measure activation memory
├── Shutdown existing KVCacheManager
├── Compute: mem_for_paged = (free_mem - non_paged - forward_mem) * free_gpu_fraction
└── Recreate KVCacheManager with optimal max_tokens
```

## Key Resource Handler Types

| Handler Type | Managed By | Buffer Source | Use Case |
|--------------|------------|---------------|----------|
| `KVPagedResourceHandler` | `KVCacheManager` | `get_buffers(idx)` | Paged KV caches for attention |
| `SSMResourceHandler` | `MambaHybridCacheManager` | `get_ssm_states(layer)` | Mamba SSM state |
| `CausalConvResourceHandler` | `MambaHybridCacheManager` | `get_conv_states(layer)` | Mamba causal conv state |
| `StateResourceHandler` | Local allocation | `handler.allocate()` | Generic per-sequence state |
| `UnpagedResourceHandler` | Local allocation | `handler.allocate()` | Unpaged per-token resources |

## Key Files and Their Responsibilities

### `tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py`

- **`ResourceHandler`** (abstract base): Interface for allocating resources
- **`KVPagedResourceHandler`**: Describes paged KV cache with `num_kv_heads`, `head_dim`, `dtype`, `kv_layout`
- **`SSMResourceHandler`**: Describes Mamba SSM state with `num_heads`, `head_dim`, `d_state`
- **`CausalConvResourceHandler`**: Describes causal conv state with `conv_dim`, `d_conv`
- **`AttentionDescriptor.get_cache_initializers()`**: Returns `ResourceHandlerDict` mapping names to handlers

### `tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py`

- **`InsertCachedAttention`**: Iterates over attention nodes, calls `get_cache_initializers()`, and registers handlers via `cm.add_resource()`
- **`InitializeCache`**: Triggers `cm.initialize_resources()` to allocate all caches
- **`ResizeKVCache`**: Runs forward pass, measures memory, and calls `cm.resize_kv_cache_manager()`

### `tensorrt_llm/_torch/auto_deploy/shim/interface.py`

- **`CachedSequenceInterface`**: Central class managing all caches
  - `_resource_lookup`: Dict of all registered resource handlers
  - `_unmanaged_resources`: List tracking locally-allocated (non-managed) resource names
  - `add_resource()`: Stores handlers in `_resource_lookup`
  - `initialize_resources()`: Initializes caches, creates cache managers, assigns views
  - `_identify_managed_kv_resources()`: Finds compatible KV handlers from `_resource_lookup`
  - `_identify_managed_state_resources()`: Finds compatible SSM/Conv handlers with constraint checking
  - `_create_kv_cache_manager()`: Creates `KVCacheManager` or `MambaHybridCacheManager`
  - `_allocate_unmanaged_resources()`: Allocates resources not managed by cache managers, tracks in `_unmanaged_resources`

## Example Flow: FlashInfer Attention

```python
# In flashinfer_attention.py
class FlashinferCachedAttention(AttentionDescriptor):
    @classmethod
    def get_cache_initializers(cls, source_attn_node, cache_config):
        k_fake = source_attn_node.args[1].meta["val"]
        return {
            "kv_cache": KVPagedResourceHandler(
                num_kv_heads=k_fake.shape[2],
                head_dim=k_fake.shape[3],
                dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype),
                kv_factor=2,
                kv_layout="HND",
            )
        }
```

**This handler gets:**

1. **Collected** by `InsertCachedAttention` → `cm.add_resource("kv_cache_0", handler)`
1. **Stored** in `_resource_lookup` during `initialize_resources()`
1. **Identified** as manageable by `_identify_managed_kv_resources()` if compatible with other KV handlers
1. **View assigned** via `self._caches["kv_cache_0"] = manager.get_buffers(0, kv_layout="HND")`

## Compatibility Rules

### KV Cache Compatibility (for KVCacheManager)

Handlers are compatible if they match on:

- `head_dim`
- `dtype`
- `kv_factor`
- `kv_layout`

Note: `num_kv_heads` can differ (supports GQA/MQA with varying head counts per layer).
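
One way to picture this rule is to group handlers by the tuple of fields that must agree and let each group share a single `KVCacheManager`; `num_kv_heads` is deliberately left out of the grouping key. A hedged sketch with duck-typed handler stand-ins (not the real classes):

```python
from collections import defaultdict
from types import SimpleNamespace


def group_compatible_kv_handlers(handlers: dict) -> dict:
    """Group handler descriptions by the fields that must match exactly.

    num_kv_heads is intentionally not part of the key, so layers with
    different head counts can still share one manager.
    """
    groups = defaultdict(list)
    for name, h in handlers.items():
        key = (h.head_dim, h.dtype, h.kv_factor, h.kv_layout)
        groups[key].append(name)
    return dict(groups)


handlers = {
    "kv_cache_0": SimpleNamespace(num_kv_heads=8, head_dim=128, dtype="bf16", kv_factor=2, kv_layout="HND"),
    "kv_cache_1": SimpleNamespace(num_kv_heads=4, head_dim=128, dtype="bf16", kv_factor=2, kv_layout="HND"),
}
print(group_compatible_kv_handlers(handlers))  # both names land in the same group
```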

### State Resource Compatibility (for MambaHybridCacheManager)

**SSM Resources**: Compatible if `state_shape` and `dtype` match.

**Conv Resources**: Compatible if `state_shape` and `dtype` match, **AND** the n_groups constraint holds:

```text
conv_dim = head_dim * num_heads + 2 * n_groups * d_state
```

If this constraint cannot be satisfied with integer `n_groups >= 0`, Conv resources fall back to local allocation.
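
Concretely, the check boils down to whether an integer `n_groups >= 0` solves the equation for the given dimensions. A small sketch of that feasibility test (field names assumed for illustration):

```python
def conv_matches_ssm(conv_dim: int, head_dim: int, num_heads: int, d_state: int) -> bool:
    """True if conv_dim = head_dim*num_heads + 2*n_groups*d_state for some integer n_groups >= 0."""
    remainder = conv_dim - head_dim * num_heads
    if remainder < 0:
        return False
    return remainder % (2 * d_state) == 0


# Example dimensions where n_groups resolves to 1: 4352 = 64*64 + 2*1*128
assert conv_matches_ssm(conv_dim=4352, head_dim=64, num_heads=64, d_state=128)
```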

@@ -60,6 +60,8 @@ The exported graph then undergoes a series of automated transformations, includi

- [Incorporating AutoDeploy into Your Own Workflow](./advanced/workflow.md)
- [Expert Configurations](./advanced/expert_configurations.md)
- [Performance Benchmarking](./advanced/benchmarking_with_trtllm_bench.md)
- [KV Cache Architecture](./advanced/kv_cache_architecture.md)
- [Export ONNX for EdgeLLM](./advanced/export_onnx.md)

## Roadmap
@@ -1,98 +0,0 @@
# Benchmarking with trtllm-bench

AutoDeploy is integrated with the `trtllm-bench` performance benchmarking utility, enabling you to measure comprehensive performance metrics such as token throughput, request throughput, and latency for your AutoDeploy-optimized models.

## Getting Started

Before benchmarking with AutoDeploy, review the [TensorRT-LLM benchmarking guide](../../../performance/perf-benchmarking.md#running-with-the-pytorch-workflow) to familiarize yourself with the standard trtllm-bench workflow and best practices.

## Basic Usage

Invoke the AutoDeploy backend by specifying `--backend _autodeploy` in your `trtllm-bench` command:

```bash
trtllm-bench \
    --model meta-llama/Llama-3.1-8B \
    throughput \
    --dataset /tmp/synthetic_128_128.txt \
    --backend _autodeploy
```

```{note}
As in the PyTorch workflow, AutoDeploy does not require a separate `trtllm-bench build` step. The model is automatically optimized during benchmark initialization.
```

## Advanced Configuration

For more granular control over AutoDeploy's behavior during benchmarking, use the `--config` flag with a YAML configuration file:

```bash
trtllm-bench \
    --model meta-llama/Llama-3.1-8B \
    throughput \
    --dataset /tmp/synthetic_128_128.txt \
    --backend _autodeploy \
    --config autodeploy_config.yaml
```

### Configuration Examples

#### Basic Performance Configuration (`autodeploy_config.yaml`)

```yaml
# runtime engine
runtime: trtllm

# model loading
skip_loading_weights: false

# Sequence configuration
max_batch_size: 256

# transform options
# KV cache configuration
kv_cache_config:
  # fraction of free memory to use for kv-caches
  free_gpu_memory_fraction: 0.9

# transform options
transforms:
  insert_cached_attention:
    # attention backend
    backend: flashinfer
  compile_model:
    # compilation backend
    backend: torch-opt
    # CUDA Graph optimization
    cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256]
```

Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GPUs.

## Configuration Options Reference

### Core Performance Settings

| Parameter | Default | Description |
|-----------|---------|-------------|
| `compile_backend` | `torch-compile` | Compilation backend: `torch-simple`, `torch-compile`, `torch-cudagraph`, `torch-opt` |
| `runtime` | `trtllm` | Runtime engine: `trtllm`, `demollm` |
| `kv_cache_config.free_gpu_memory_fraction` | `0.9` | Fraction of available GPU memory for KV cache (0.0-1.0) |
| `skip_loading_weights` | `false` | Skip weight loading for architecture-only benchmarks |

### CUDA Graph Optimization

| Parameter | Default | Description |
|-----------|---------|-------------|
| `cuda_graph_batch_sizes` | `null` | List of batch sizes for CUDA graph creation |

```{tip}
For optimal CUDA graph performance, specify batch sizes that match your expected workload patterns. For example: `[1, 2, 4, 8, 16, 32, 64, 128]`
```
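
If you are unsure which sizes to pick, one simple heuristic is powers of two up to `max_batch_size`, mirroring the example lists used on this page (a suggestion, not a requirement):

```python
def default_cuda_graph_batch_sizes(max_batch_size: int) -> list[int]:
    """Powers of two up to and including max_batch_size."""
    sizes = []
    bs = 1
    while bs <= max_batch_size:
        sizes.append(bs)
        bs *= 2
    return sizes


print(default_cuda_graph_batch_sizes(256))  # [1, 2, 4, 8, 16, 32, 64, 128, 256]
```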

## Performance Optimization Tips

1. **Memory Management**: Set `kv_cache_config.free_gpu_memory_fraction` to 0.8-0.9 for optimal KV cache utilization
1. **Compilation Backend**: Use `torch-opt` for production workloads
1. **Attention Backend**: `flashinfer` generally provides the best performance for most models
1. **CUDA Graphs**: Enable CUDA graphs for batch sizes that match your production traffic patterns.

@@ -1,49 +0,0 @@
# Example Run Script

To build and run the AutoDeploy example, use the `examples/auto_deploy/build_and_run_ad.py` script:

```bash
cd examples/auto_deploy
python build_and_run_ad.py --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
```

You can configure your experiment with various options. Use the `-h/--help` flag to see available options:

```bash
python build_and_run_ad.py --help
```

The following is a non-exhaustive list of common configuration options:

| Configuration Key | Description |
|-------------------|-------------|
| `--model` | The HF model card or path to a HF checkpoint folder |
| `--args.model-factory` | Choose model factory implementation (`"AutoModelForCausalLM"`, ...) |
| `--args.skip-loading-weights` | Only load the architecture, not the weights |
| `--args.model-kwargs` | Extra kwargs that are passed to the model initializer in the model factory |
| `--args.tokenizer-kwargs` | Extra kwargs that are passed to the tokenizer initializer in the model factory |
| `--args.world-size` | The number of GPUs used for auto-sharding the model |
| `--args.runtime` | Specifies which type of Engine to use during runtime (`"demollm"` or `"trtllm"`) |
| `--args.compile-backend` | Specifies how to compile the graph at the end |
| `--args.attn-backend` | Specifies kernel implementation for attention |
| `--args.mla-backend` | Specifies implementation for multi-head latent attention |
| `--args.max-seq-len` | Maximum sequence length for inference/cache |
| `--args.max-batch-size` | Maximum dimension for statically allocated KV cache |
| `--args.attn-page-size` | Page size for attention |
| `--prompt.batch-size` | Number of queries to generate |
| `--benchmark.enabled` | Whether to run the built-in benchmark (true/false) |

For default values and additional configuration options, refer to the `ExperimentConfig` class in the `examples/auto_deploy/build_and_run_ad.py` file.

The following is a more complete example of using the script:

```bash
cd examples/auto_deploy
python build_and_run_ad.py \
    --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" \
    --args.world-size 2 \
    --args.runtime "demollm" \
    --args.compile-backend "torch-compile" \
    --args.attn-backend "flashinfer" \
    --benchmark.enabled True
```

@@ -1,268 +0,0 @@
# Expert Configuration of LLM API

For advanced TensorRT-LLM users, the full set of `tensorrt_llm._torch.auto_deploy.llm_args.LlmArgs` is exposed. Use at your own risk. The argument list may diverge from the standard TRT-LLM argument list.

- All configuration fields used by the AutoDeploy core pipeline, `InferenceOptimizer`, are exposed exclusively in `AutoDeployConfig` in `tensorrt_llm._torch.auto_deploy.llm_args`. Please refer to those fields first.
- For advanced users, the full set of `LlmArgs` in `tensorrt_llm._torch.auto_deploy.llm_args` can be used to configure the AutoDeploy `LLM` API, including runtime options.
- Note that some fields in the full `LlmArgs` object are overlapping, duplicated, and/or _ignored_ in AutoDeploy, particularly arguments pertaining to configuring the model itself, since AutoDeploy's model ingestion and optimization pipeline differs significantly from the default manual workflow in TensorRT-LLM.
- However, with proper care, the full `LlmArgs` object can be used to configure advanced runtime options in TensorRT-LLM.
- Any valid field can simply be provided as a keyword argument (`**kwargs`) to the AutoDeploy `LLM` API.

# Expert Configuration of `build_and_run_ad.py`

For advanced users, `build_and_run_ad.py` offers flexible configuration through an argument parser powered by Pydantic Settings and OmegaConf. You can use dot notation for CLI arguments, provide multiple YAML configuration files, and rely on well-defined configuration precedence rules to create complex deployment configurations.

## CLI Arguments with Dot Notation

The script supports flexible CLI argument parsing using dot notation to modify nested configurations dynamically. You can target any field in both the `ExperimentConfig` in `examples/auto_deploy/build_and_run_ad.py` and the nested `AutoDeployConfig` or `LlmArgs` objects in `tensorrt_llm._torch.auto_deploy.llm_args`:

```bash
# Configure model parameters
# NOTE: config values like num_hidden_layers are automatically resolved into the appropriate nested
# dict value ``{"args": {"model_kwargs": {"num_hidden_layers": 10}}}`` although not explicitly
# specified as CLI arg
python build_and_run_ad.py \
    --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
    --args.model-kwargs.num-hidden-layers=10 \
    --args.model-kwargs.hidden-size=2048 \
    --args.tokenizer-kwargs.padding-side=left

# Configure runtime and backend options
python build_and_run_ad.py \
    --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" \
    --args.world-size=2 \
    --args.compile-backend=torch-opt \
    --args.attn-backend=flashinfer

# Configure prompting and benchmarking
python build_and_run_ad.py \
    --model "microsoft/phi-4" \
    --prompt.batch-size=4 \
    --prompt.sp-kwargs.max-tokens=200 \
    --prompt.sp-kwargs.temperature=0.7 \
    --benchmark.enabled=true \
    --benchmark.bs=8 \
    --benchmark.isl=1024
```
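
Conceptually, each dotted CLI key is split on `.` and folded into a nested dictionary before the Pydantic models validate it. A rough sketch of that resolution (illustrative only, not the actual parser):

```python
def fold_dotted_args(pairs: dict) -> dict:
    """Turn {'args.model-kwargs.num-hidden-layers': 10} into nested dicts."""
    result: dict = {}
    for dotted_key, value in pairs.items():
        node = result
        *parents, leaf = dotted_key.replace("-", "_").split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return result


print(fold_dotted_args({"args.model-kwargs.num-hidden-layers": 10}))
# {'args': {'model_kwargs': {'num_hidden_layers': 10}}}
```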

## YAML Configuration Files

Both `ExperimentConfig` and `AutoDeployConfig`/`LlmArgs` inherit from `DynamicYamlMixInForSettings`, which enables you to provide multiple YAML configuration files that are automatically deep-merged at runtime.

Create a YAML configuration file (e.g., `my_config.yaml`):

```yaml
# my_config.yaml
args:
  model_kwargs:
    num_hidden_layers: 12
    hidden_size: 1024
  world_size: 4
  max_seq_len: 2048
  max_batch_size: 16
  transforms:
    detect_sharding:
      support_partial_config: true
    insert_cached_attention:
      backend: triton
    compile_model:
      backend: torch-compile

prompt:
  batch_size: 8
  sp_kwargs:
    max_tokens: 150
    temperature: 0.8
    top_k: 50
```

Create an additional override file (e.g., `production.yaml`):

```yaml
# production.yaml
args:
  world_size: 8
  max_batch_size: 32
  transforms:
    compile_model:
      backend: torch-opt
```

Then use these configurations:

```bash
# Using single YAML config
python build_and_run_ad.py \
    --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
    --yaml-extra my_config.yaml

# Using multiple YAML configs (deep merged in order, later files have higher priority)
python build_and_run_ad.py \
    --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
    --yaml-extra my_config.yaml production.yaml

# Targeting nested AutoDeployConfig with separate YAML
python build_and_run_ad.py \
    --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
    --yaml-extra my_config.yaml \
    --args.yaml-extra autodeploy_overrides.yaml
```

## Configuration Precedence and Deep Merging

The configuration system follows a precedence order in which higher priority sources override lower priority ones:

1. **CLI Arguments** (highest priority) - Direct command line arguments
1. **YAML Configs** - Files specified via `--yaml-extra` and `--args.yaml-extra`
1. **Default Settings** (lowest priority) - Built-in defaults from the config classes

**Deep Merging**: Unlike simple overwriting, deep merging recursively combines nested dictionaries. For example:

```yaml
# Base config
args:
  model_kwargs:
    num_hidden_layers: 10
    hidden_size: 1024
  max_seq_len: 2048
```

```yaml
# Override config
args:
  model_kwargs:
    hidden_size: 2048 # This will override
    # num_hidden_layers: 10 remains unchanged
  world_size: 4 # This gets added
```
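
The merge semantics can be reproduced with a small recursive helper (a sketch of the behavior, not the code AutoDeploy uses internally):

```python
def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge override into base; override wins on conflicts."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


base = {"args": {"model_kwargs": {"num_hidden_layers": 10, "hidden_size": 1024}, "max_seq_len": 2048}}
override = {"args": {"model_kwargs": {"hidden_size": 2048}, "world_size": 4}}
# num_hidden_layers and max_seq_len survive; hidden_size is overridden; world_size is added
print(deep_merge(base, override))
```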

**Nested Config Behavior**: When using nested configurations, outer YAML configuration files become initialization settings for inner objects, giving them higher precedence:

```bash
# The outer yaml-extra affects the entire ExperimentConfig
# The inner args.yaml-extra affects only the AutoDeployConfig
python build_and_run_ad.py \
    --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
    --yaml-extra experiment_config.yaml \
    --args.yaml-extra autodeploy_config.yaml \
    --args.world-size=8  # CLI override beats both YAML configs
```

## Sharding Configuration

The `detect_sharding` transform automatically detects and applies sharding strategies to the model. It supports multiple sharding sources and dimensions, allowing flexible configuration for different model architectures and parallelism strategies.

### Configuration Parameters

The `detect_sharding` transform accepts the following configuration parameters:

#### `simple_shard_only` (bool, default: `false`)

When set to `true`, forces simple sharding (row-wise sharding with all-gather) for all linear layers, bypassing more sophisticated column/row sharding strategies. This is useful when you want a uniform sharding approach across all layers or when debugging sharding issues.

#### `sharding_source` (list, default: `['manual', 'factory', 'heuristic']`)

Specifies the priority order of sharding sources. The order matters: if multiple sources try to apply sharding to the same layer, only the first one in the list will be applied. The available sources are:

- **`'manual'`**: Uses the manually provided sharding configuration from the `manual_config` parameter
- **`'factory'`**: Uses factory-provided sharding configuration (e.g., from Hugging Face model configs)
- **`'heuristic'`**: Uses automatic heuristic-based sharding detection based on layer patterns

Example: If both `manual` and `heuristic` try to apply sharding to layer L, only the `manual` transformation will be applied since it appears first in the list.

#### `support_partial_config` (bool, default: `true`)

When `true`, allows partial sharding configurations where not all layers need to be specified in the manual or factory config. Layers not explicitly configured will be handled by heuristic sharding or left unsharded. When `false`, the configuration must specify all layers or it will be invalidated and skipped.

#### `sharding_dims` (list, default: `['tp', 'ep', 'bmm']`)

Specifies which sharding dimensions to apply during heuristic sharding. The available dimensions are:

- **`'tp'`**: Tensor parallelism - applies column/row sharding for standard transformer layers
- **`'ep'`**: Expert parallelism - shards experts across ranks for Mixture-of-Experts (MoE) models
- **`'bmm'`**: Batch matrix multiplication sharding - shards batch matrix multiplication operations
- **`'ssm'`**: State space model sharding - applies specialized sharding for Mamba/SSM layers

You can enable multiple dimensions simultaneously. For example, `['tp', 'ep']` will apply both tensor parallelism and expert parallelism.

#### `process_grid` (dict, default: `None`)

Specifies a 2D device mesh for hybrid EP+TP parallelism.

- NOTE 1: This grid applies only to the MoE layers. Attention, Mamba, and MLP layers are unaffected.
- NOTE 2: The order of the keys matters. The process grid's layout follows generalized column-major order, that is, the last dimension is stride-one.
- NOTE 3: `ep * tp` must be equal to the provided world size. Otherwise, the mesh will be considered invalid, and 1D ep-only parallelism will be applied.

Example:

```yaml
process_grid: {'ep': 2, 'tp': 2}
```

If `world_size == 4`, ranks \[0,1\] and \[2,3\] will create two EP groups. Experts will be distributed across these two groups, and internally, TP=2 column-row sharding will be applied.
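
The generalized column-major convention means that for `{'ep': 2, 'tp': 2}` the last key (`tp`) is the stride-one (fastest-varying) dimension. A short sketch of how ranks map onto such a grid (illustration only):

```python
def ep_tp_groups(world_size: int, ep: int, tp: int) -> list[list[int]]:
    """Map ranks onto an (ep, tp) grid where tp is the stride-one dimension."""
    if ep * tp != world_size:
        raise ValueError("invalid mesh: ep * tp must equal world_size")
    ranks = list(range(world_size))
    # Each EP group holds tp consecutive ranks.
    return [ranks[i * tp:(i + 1) * tp] for i in range(ep)]


print(ep_tp_groups(world_size=4, ep=2, tp=2))  # [[0, 1], [2, 3]]
```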

#### `requires_shape_prop` (bool, default: `true`)

Whether shape propagation is required before applying this transform. Shape propagation enables the transform to make informed decisions about sharding strategies based on tensor dimensions.

### Manual TP Sharding Configuration

For advanced users, you can provide a manual sharding configuration. An example of such a setting:

```yaml
args:
  transforms:
    detect_sharding:
      manual_config:
        head_dim: 128
        tp_plan:
          # mamba SSM layers
          in_proj: mamba
          out_proj: rowwise
          # attention layers
          q_proj: colwise
          k_proj: colwise
          v_proj: colwise
          o_proj: rowwise
          # NOTE: for performance reasons, consider not sharding the following
          # layers at all. Commenting out the following layers will replicate
          # them across ranks.
          # MLP and shared experts in MoE layers
          gate_proj: colwise
          up_proj: colwise
          down_proj: rowwise
          # MoLE: latent projections: simple shard
          fc1_latent_proj: gather
          fc2_latent_proj: gather
```

The `tp_plan` dictionary maps layer names (using module paths with wildcard `*` support) to sharding strategies (a small matching sketch follows the list):

- **`colwise`**: Column-wise sharding (splits the weight matrix along columns)
- **`rowwise`**: Row-wise sharding (splits the weight matrix along rows)
- **`mamba`**: Specialized sharding for Mamba SSM layers
- **`gather`**: Simple shard with row-wise sharding and all-gather operation
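
A hedged sketch of the matching idea behind such a plan (glob-style matching shown for illustration; the actual resolution logic may differ):

```python
from fnmatch import fnmatch
from typing import Optional


def resolve_strategy(module_path: str, tp_plan: dict) -> Optional[str]:
    """Return the sharding strategy of the first pattern matching the module path."""
    for pattern, strategy in tp_plan.items():
        # A bare layer name such as "q_proj" matches any module path ending in it.
        glob = pattern if "*" in pattern else f"*{pattern}"
        if fnmatch(module_path, glob):
            return strategy
    return None  # unmatched layers stay replicated / unsharded


tp_plan = {"q_proj": "colwise", "o_proj": "rowwise", "down_proj": "rowwise"}
print(resolve_strategy("model.layers.0.self_attn.q_proj", tp_plan))  # colwise
```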

## Built-in Default Configuration

Both the `AutoDeployConfig` and `LlmArgs` classes automatically load a built-in `default.yaml` configuration file that provides defaults for the AutoDeploy inference optimizer pipeline. This file is specified in the `_get_config_dict()` function in `tensorrt_llm._torch.auto_deploy.llm_args` and defines default transform configurations for graph optimization stages.

The built-in defaults are automatically merged with your configurations at the lowest priority level, ensuring that your custom settings always override the defaults. You can inspect the current default configuration to understand the baseline transform pipeline:

```bash
# View the default configuration
cat tensorrt_llm/_torch/auto_deploy/config/default.yaml

# Override specific transform settings
python build_and_run_ad.py \
    --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" \
    --args.transforms.export-to-gm.strict=true
```

@@ -1,14 +0,0 @@
# Logging Level

Use the following environment variable to specify the logging level of our built-in logger. The levels are ordered by decreasing verbosity:

```bash
AUTO_DEPLOY_LOG_LEVEL=DEBUG
AUTO_DEPLOY_LOG_LEVEL=INFO
AUTO_DEPLOY_LOG_LEVEL=WARNING
AUTO_DEPLOY_LOG_LEVEL=ERROR
AUTO_DEPLOY_LOG_LEVEL=INTERNAL_ERROR
```

The default log level is `INFO`.
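
If you drive AutoDeploy from Python rather than the shell, the variable can also be set programmatically; the assumption here is that it must be set before AutoDeploy is imported so the logger picks it up:

```python
import os

# Assumption: the level is read when the AutoDeploy logger is initialized,
# so set it before importing or launching AutoDeploy.
os.environ.setdefault("AUTO_DEPLOY_LOG_LEVEL", "DEBUG")
```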

@@ -1,88 +0,0 @@
# Serving with trtllm-serve

AutoDeploy integrates with the OpenAI-compatible `trtllm-serve` CLI so you can expose AutoDeploy-optimized models over HTTP without writing server code. This page shows how to launch the server with the AutoDeploy backend, configure it via YAML, and validate it with a simple request.

## Quick start

Launch `trtllm-serve` with the AutoDeploy backend by setting `--backend _autodeploy`:

```bash
trtllm-serve \
    meta-llama/Llama-3.1-8B-Instruct \
    --backend _autodeploy
```

- `model`: HF name or local path
- `--backend _autodeploy`: uses the AutoDeploy runtime

Once the server is ready, test it with an OpenAI-compatible request:

```bash
curl -s http://localhost:8000/v1/chat/completions \
    -H 'Content-Type: application/json' \
    -d '{
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages":[{"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Where is New York? Tell me in a single sentence."}],
        "max_tokens": 32
    }'
```
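
Because the endpoint is OpenAI-compatible, the same request can also be issued with the OpenAI Python client (assuming `pip install openai`; the API key below is a placeholder, since a local deployment typically does not validate it):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")
resp = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Where is New York? Tell me in a single sentence."},
    ],
    max_tokens=32,
)
print(resp.choices[0].message.content)
```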

## Configuration via YAML

Use `--config` to supply a YAML file that augments or overrides server/runtime settings.

```bash
trtllm-serve \
    meta-llama/Llama-3.1-8B \
    --backend _autodeploy \
    --config autodeploy_config.yaml
```

Example `autodeploy_config.yaml`:

```yaml
# runtime engine
runtime: trtllm

# model loading
skip_loading_weights: false

# Sequence configuration
max_batch_size: 256

# multi-gpu execution
world_size: 1

# transform options
# KV cache configuration
kv_cache_config:
  # fraction of free memory to use for kv-caches
  free_gpu_memory_fraction: 0.9

# transform options
transforms:
  insert_cached_attention:
    # attention backend
    backend: flashinfer
  compile_model:
    # compilation backend
    backend: torch-opt
    # CUDA Graph optimization
    cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256]
```

## Limitations and tips

- KV cache block reuse is disabled automatically for the AutoDeploy backend
- The AutoDeploy backend does not yet support disaggregated serving (work in progress)
- For best performance:
  - Prefer `compile_backend: torch-opt`
  - Use `attn_backend: flashinfer`
  - Set realistic `cuda_graph_batch_sizes` that match expected traffic
  - Tune `kv_cache_config.free_gpu_memory_fraction` to 0.8–0.9

## See also

- [AutoDeploy overview](../auto-deploy.md)
- [Benchmarking with trtllm-bench](./benchmarking_with_trtllm_bench.md)

@@ -1,30 +0,0 @@
### Incorporating `auto_deploy` into your own workflow

AutoDeploy can be seamlessly integrated into existing workflows using TRT-LLM's LLM high-level API. This section provides an example for configuring and invoking AutoDeploy in custom applications.

The following example demonstrates how to build an LLM object with AutoDeploy integration:

```python
from tensorrt_llm._torch.auto_deploy import LLM


# Construct the LLM high-level interface object with autodeploy as backend
llm = LLM(
    model=<HF_MODEL_CARD_OR_DIR>,
    world_size=<DESIRED_WORLD_SIZE>,
    model_factory="AutoModelForCausalLM",  # choose appropriate model factory
    model_kwargs={"num_hidden_layers": 2},  # test with smaller model configuration
    transforms={
        "insert_cached_attention": {"backend": "flashinfer"},  # or "triton"
        "insert_cached_mla_attention": {"backend": "MultiHeadLatentAttention"},
        "compile_model": {"backend": "torch-compile"},
        "detect_sharding": {"simple_shard_only": False},
    },
    skip_loading_weights=False,
    max_seq_len=<MAX_SEQ_LEN>,
    max_batch_size=<MAX_BATCH_SIZE>,
)
```
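
Once constructed, the object is driven like any other TRT-LLM `LLM` instance. A hedged sketch that reuses the `llm` object from the example above and follows the standard LLM-API generation pattern (exact result fields may differ across versions):

```python
from tensorrt_llm import SamplingParams

prompts = ["What is the capital of France?"]
sampling_params = SamplingParams(max_tokens=32, temperature=0.7)

# Each result carries the generated completions for one prompt.
for output in llm.generate(prompts, sampling_params):
    print(output.outputs[0].text)
```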

For more information about configuring AutoDeploy via the `LLM` API using `**kwargs`, see the AutoDeploy LLM API in `tensorrt_llm._torch.auto_deploy.llm` and the `AutoDeployConfig` class in `tensorrt_llm._torch.auto_deploy.llm_args`.

@@ -1,82 +0,0 @@
# AutoDeploy

```{note}
This project is under active development and is currently in a prototype stage. The code is experimental, subject to change, and may include backward-incompatible updates. While we strive for correctness, there are no guarantees regarding functionality, stability, or reliability.
```

### Seamless Model Deployment from PyTorch to TensorRT-LLM

AutoDeploy is a prototype designed to simplify and accelerate the deployment of PyTorch models, including off-the-shelf models such as those from the Hugging Face Transformers library, to TensorRT-LLM.


<sub><em>AutoDeploy overview and relation with TensorRT-LLM's LLM API</em></sub>

AutoDeploy provides an alternative method for deploying models using the LLM API without requiring code changes to the source model (for example, Hugging Face Transformers models) or manual implementation of inference optimizations, such as KV-caches, multi-GPU parallelism, or quantization. Instead, AutoDeploy extracts a computation graph from the source model and applies inference optimizations through a series of automated graph transformations. AutoDeploy generates an inference-optimized graph that can be directly executed in the TensorRT-LLM PyTorch runtime and leverages various runtime optimizations including in-flight batching, paging, and overlap scheduling.

### Key Features

- **Seamless Model Translation:** Automatically converts PyTorch/Hugging Face models to TensorRT-LLM without manual rewrites.
- **Unified Model Definition:** Maintain a single source of truth with your original PyTorch/Hugging Face model.
- **Optimized Inference:** Built-in transformations for sharding, quantization, KV-cache integration, MHA fusion, and CudaGraph optimization.
- **Immediate Deployment:** Day-0 support for models with continuous performance enhancements.
- **Quick Setup & Prototyping:** Lightweight pip package for easy installation with a demo environment for fast testing.

## Get Started

1. **Install AutoDeploy:**

   AutoDeploy is included with the TRT-LLM installation.

   ```bash
   sudo apt-get -y install libopenmpi-dev && pip3 install --upgrade pip setuptools && pip3 install tensorrt_llm
   ```

   You can refer to the [TRT-LLM installation guide](../../installation/linux.md) for more information.

2. **Run Llama Example:**

   You are now ready to run an in-framework Llama demo.

   The general entry point for running the AutoDeploy demo is the `build_and_run_ad.py` script. Checkpoints are loaded directly from Hugging Face (HF) or a local HF-like directory:

   ```bash
   cd examples/auto_deploy
   python build_and_run_ad.py --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
   ```

## Support Matrix

AutoDeploy streamlines the model deployment process through an automated workflow designed for efficiency and performance. The workflow begins with a PyTorch model, which is exported using `torch.export` to generate a standard Torch graph. This graph contains core PyTorch ATen operations alongside custom attention operations, determined by the attention backend specified in the configuration.

The exported graph then undergoes a series of automated transformations, including graph sharding, KV-cache insertion, and GEMM fusion, to optimize model performance. After these transformations, the graph is compiled using one of the supported compile backends (like `torch-opt`), followed by deployment via the TensorRT-LLM runtime.

- [Support Matrix](support_matrix.md)

## Advanced Usage

- [Example Run Script](./advanced/example_run.md)
- [Logging Level](./advanced/logging.md)
- [Incorporating AutoDeploy into Your Own Workflow](./advanced/workflow.md)
- [Expert Configurations](./advanced/expert_configurations.md)
- [Performance Benchmarking](./advanced/benchmarking_with_trtllm_bench.md)
- [Serving with trtllm-serve](./advanced/serving_with_trtllm_serve.md)
- [Export ONNX for EdgeLLM](./advanced/export_onnx.md)

## Roadmap

We are actively expanding AutoDeploy to support a broader range of model architectures and inference features.

**Upcoming Model Support:**

- Vision-Language Models (VLMs)
- Structured State Space Models (SSMs) and Linear Attention architectures

**Planned Features:**

- Low-Rank Adaptation (LoRA)
- Speculative Decoding for accelerated generation

To track development progress and contribute, visit our [GitHub Project Board](https://github.com/orgs/NVIDIA/projects/83/views/13).
We welcome community contributions; see `examples/auto_deploy/CONTRIBUTING.md` for guidelines.
@@ -1,129 +0,0 @@
## Support Matrix

AutoDeploy streamlines model deployment with an automated workflow designed for efficiency and performance. The workflow begins with a PyTorch model, which is exported using `torch.export` to generate a standard Torch graph. This graph contains core PyTorch ATen operations alongside custom attention operations, determined by the attention backend specified in the configuration.

The exported graph then undergoes a series of automated transformations, including graph sharding, KV-cache insertion, and GEMM fusion, to optimize model performance. After these transformations, the graph is compiled using one of the supported compile backends (like `torch-opt`), followed by deployment via the TRT-LLM runtime.

### Supported Models

**Bring Your Own Model**: AutoDeploy leverages `torch.export` and dynamic graph pattern matching, enabling seamless integration for a wide variety of models without relying on hard-coded architectures.

AutoDeploy supports Hugging Face models compatible with `AutoModelForCausalLM` and `AutoModelForImageTextToText`.
In addition, the following models have been officially validated using the default configuration: `runtime=trtllm`, `compile_backend=torch-compile`, and `attn_backend=flashinfer`.

<details>
<summary>Click to expand supported models list</summary>

- Qwen/QwQ-32B
- Qwen/Qwen2.5-0.5B-Instruct
- Qwen/Qwen2.5-1.5B-Instruct
- Qwen/Qwen2.5-3B-Instruct
- Qwen/Qwen2.5-7B-Instruct
- Qwen/Qwen3-0.6B
- Qwen/Qwen3-235B-A22B
- Qwen/Qwen3-30B-A3B
- Qwen/Qwen3-4B
- Qwen/Qwen3-8B
- TinyLlama/TinyLlama-1.1B-Chat-v1.0
- apple/OpenELM-1_1B-Instruct
- apple/OpenELM-270M-Instruct
- apple/OpenELM-3B-Instruct
- apple/OpenELM-450M-Instruct
- bigcode/starcoder2-15b-instruct-v0.1
- bigcode/starcoder2-7b
- deepseek-ai/DeepSeek-Prover-V1.5-SFT
- deepseek-ai/DeepSeek-Prover-V2-7B
- deepseek-ai/DeepSeek-R1-Distill-Llama-70B
- deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
- deepseek-ai/DeepSeek-R1-Distill-Qwen-32B
- google/codegemma-7b-it
- google/gemma-1.1-7b-it
- google/gemma-2-27b-it
- google/gemma-2-2b-it
- google/gemma-2-9b-it
- google/gemma-2b
- google/gemma-3-1b-it
- ibm-granite/granite-3.1-2b-instruct
- ibm-granite/granite-3.1-8b-instruct
- ibm-granite/granite-3.3-2b-instruct
- ibm-granite/granite-3.3-8b-instruct
- ibm-granite/granite-guardian-3.1-2b
- ibm-granite/granite-guardian-3.2-5b
- meta-llama/CodeLlama-34b-Instruct-hf
- meta-llama/CodeLlama-7b-Instruct-hf
- meta-llama/CodeLlama-7b-Python-hf
- meta-llama/Llama-2-13b-chat-hf
- meta-llama/Llama-2-7b-chat-hf
- meta-llama/Llama-3.1-8B-Instruct
- meta-llama/Llama-3.2-1B-Instruct
- meta-llama/Llama-3.2-3B-Instruct
- meta-llama/Llama-3.3-70B-Instruct
- meta-llama/Llama-4-Maverick-17B-128E-Instruct
- meta-llama/Llama-4-Scout-17B-16E-Instruct
- microsoft/Phi-3-medium-128k-instruct
- microsoft/Phi-3-medium-4k-instruct
- microsoft/Phi-4-mini-instruct
- microsoft/Phi-4-mini-reasoning
- microsoft/Phi-4-reasoning
- microsoft/Phi-4-reasoning-plus
- microsoft/phi-4
- mistralai/Codestral-22B-v0.1
- mistralai/Mistral-7B-Instruct-v0.2
- mistralai/Mistral-7B-Instruct-v0.3
- mistralai/Mixtral-8x22B-Instruct-v0.1
- nvidia/Llama-3.1-405B-Instruct-FP8
- nvidia/Llama-3.1-70B-Instruct-FP8
- nvidia/Llama-3.1-8B-Instruct-FP8
- nvidia/Llama-3.1-Minitron-4B-Depth-Base
- nvidia/Llama-3.1-Minitron-4B-Width-Base
- nvidia/Llama-3.1-Nemotron-70B-Instruct-HF
- nvidia/Llama-3.1-Nemotron-Nano-8B-v1
- nvidia/Llama-3_1-Nemotron-51B-Instruct
- nvidia/Llama-3_1-Nemotron-Ultra-253B-v1
- nvidia/Llama-3_1-Nemotron-Ultra-253B-v1-FP8
- nvidia/Llama-3_3-Nemotron-Super-49B-v1
- nvidia/Mistral-NeMo-Minitron-8B-Base
- nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
- nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8
- perplexity-ai/r1-1776-distill-llama-70b

</details>

### Runtime Integrations

AutoDeploy runs natively with the complete `TRT-LLM` stack via the `LLM` API. In addition, we provide a lightweight wrapper of the `LLM` API for onboarding and debugging new models:

| `"runtime"` | Description |
|-------------|-------------|
| `trtllm` | A robust, production-grade runtime optimized for high-performance inference. |
| `demollm` | A lightweight runtime wrapper designed for development and testing, featuring a naive scheduler and KV-cache manager for simplified debugging and testing. |

### Compile Backends

AutoDeploy supports multiple backends for compiling the exported Torch graph:

| `"compile_backend"` | Description |
|--------------------|-------------|
| `torch-simple` | Exports the graph without additional optimizations. |
| `torch-compile` | Applies `torch.compile` to the graph after all AutoDeploy transformations have been completed. |
| `torch-cudagraph` | Performs CUDA graph capture (without torch.compile). |
| `torch-opt` | Uses `torch.compile` along with CUDA Graph capture to enhance inference performance. |

### Attention Backends

Optimize attention operations with different attention kernel implementations:

| `"attn_backend"` | Description |
|----------------------|-------------|
| `triton` | Custom fused multi-head attention (MHA) with KV Cache kernels for efficient attention processing. |
| `flashinfer` | Uses optimized attention kernels with KV Cache from the [`flashinfer`](https://github.com/flashinfer-ai/flashinfer.git) library. |

### Precision Support

AutoDeploy supports models with various precision formats, including quantized checkpoints generated by [`Model-Optimizer`](https://github.com/NVIDIA/Model-Optimizer).

**Supported precision types include:**

- BF16 / FP16 / FP32
- FP8
- [NVFP4](https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/)
@@ -1,6 +1,6 @@
# 🔥🚀⚡ AutoDeploy Examples

This folder contains runnable examples for **AutoDeploy**. For general AutoDeploy documentation, motivation, support matrix, and feature overview, please see the [official docs](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html).
This folder contains runnable examples for **AutoDeploy**. For general AutoDeploy documentation, motivation, support matrix, and feature overview, please see the [official docs](https://nvidia.github.io/TensorRT-LLM/features/auto_deploy/auto-deploy.html).

______________________________________________________________________
@ -1,697 +0,0 @@
|
||||
"""Internal triton attention ops that are not actively used in the auto-deploy pipeline."""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from .triton_attention import _flattened_context_mha, _generate_mha
|
||||
from .triton_kernels.attention_with_kv_cache import (
|
||||
attention_kv_stage2,
|
||||
context_attention_kv,
|
||||
context_attention_kv_flattened,
|
||||
gqa_attention_kv_stage1,
|
||||
update_kv_cache_rope_fusion,
|
||||
)
|
||||
from .triton_kernels.attention_with_paged_kv_cache import (
|
||||
attention_kv_paged_stage1,
|
||||
context_attention_kv_paged,
|
||||
update_paged_kv_cache,
|
||||
)
|
||||
|
||||
|
||||
def _paged_generate_mha(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
page_table: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
):
|
||||
b, (n_heads, d_head) = q.shape[0], q.shape[-2:]
|
||||
PAGE_SIZE, n_kv_heads = k_cache.shape[1:3]
|
||||
device = q.device
|
||||
|
||||
SEQ_BLOCK_SIZE = 64
|
||||
num_blocks = (max_seq_len + SEQ_BLOCK_SIZE - 1) // SEQ_BLOCK_SIZE
|
||||
stage1_output_values = torch.empty(
|
||||
b, n_heads, num_blocks, d_head, device=device, dtype=torch.float32
|
||||
)
|
||||
stage1_output_logsumexp = torch.empty(
|
||||
b, n_heads, num_blocks, device=device, dtype=torch.float32
|
||||
) - float("inf")
|
||||
|
||||
(
|
||||
update_paged_kv_cache[(b, n_kv_heads, 1)](
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
None,
|
||||
k_cache,
|
||||
v_cache,
|
||||
cache_loc,
|
||||
input_pos,
|
||||
page_table,
|
||||
n_kv_heads,
|
||||
d_head,
|
||||
SEQ_BLOCK_SIZE,
|
||||
max_seq_len,
|
||||
PAGE_SIZE,
|
||||
page_table.stride(0),
|
||||
GENERATE_ONLY=True,
|
||||
),
|
||||
)
|
||||
|
||||
attention_kv_paged_stage1[
|
||||
(
|
||||
b,
|
||||
n_heads,
|
||||
num_blocks,
|
||||
)
|
||||
](
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
cache_loc,
|
||||
page_table,
|
||||
input_pos,
|
||||
stage1_output_values,
|
||||
stage1_output_logsumexp,
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
d_head,
|
||||
SEQ_BLOCK_SIZE,
|
||||
PAGE_SIZE,
|
||||
page_table.stride(0),
|
||||
)
|
||||
attention_kv_stage2[(b, n_heads, 1)](
|
||||
stage1_output_values,
|
||||
stage1_output_logsumexp,
|
||||
out,
|
||||
input_pos,
|
||||
num_blocks,
|
||||
n_heads,
|
||||
d_head,
|
||||
SEQ_BLOCK_SIZE,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def _paged_context_mha(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
page_table: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
max_seq_len: int, # max cache length of sequence, kv_cache shape don't provide this info.
|
||||
) -> None:
|
||||
# NOTE: s_total == sum(seq_len)
|
||||
s_total, n_heads, d_head = q.shape
|
||||
PAGE_SIZE, n_kv_heads = k_cache.shape[1:3]
|
||||
BATCH_SIZE = len(input_pos)
|
||||
SEQ_BLOCK = 32
|
||||
(
|
||||
update_paged_kv_cache[
|
||||
(BATCH_SIZE, n_kv_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
|
||||
](
|
||||
k,
|
||||
v,
|
||||
seq_len,
|
||||
seq_start,
|
||||
k_cache,
|
||||
v_cache,
|
||||
cache_loc,
|
||||
input_pos,
|
||||
page_table,
|
||||
n_kv_heads,
|
||||
d_head,
|
||||
SEQ_BLOCK,
|
||||
max_seq_len,
|
||||
PAGE_SIZE,
|
||||
page_table.stride(0),
|
||||
GENERATE_ONLY=False,
|
||||
),
|
||||
)
|
||||
softmax_scale = 1.0 / math.sqrt(d_head)
|
||||
grid = (BATCH_SIZE, n_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
|
||||
context_attention_kv_paged[grid](
|
||||
q,
|
||||
seq_len,
|
||||
seq_start,
|
||||
k_cache,
|
||||
v_cache,
|
||||
cache_loc,
|
||||
input_pos,
|
||||
page_table,
|
||||
softmax_scale,
|
||||
out,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
d_head,
|
||||
SEQ_BLOCK,
|
||||
max_seq_len,
|
||||
PAGE_SIZE,
|
||||
page_table.stride(0),
|
||||
num_stages=2,
|
||||
)
|
||||
|
||||
|
||||
@torch.library.custom_op(
|
||||
"auto_deploy::triton_attention_fused_mha_with_paged_cache", mutates_args=()
|
||||
)
|
||||
def fused_mha_with_paged_cache(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
page_table: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
freqs_cis: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Fused MHA with paged cache that takes raw input from q, k, v GEMMs.
|
||||
|
||||
NOTE: this op can also handle seq_len==0, which might be useful for CUDAGRAPH.
|
||||
"""
|
||||
# b, s info
|
||||
# NOTE: b, s are just the shapes of the input tensor q; not necessarily the number of sequences.
|
||||
# Generally speaking, we expect one of two cases here:
|
||||
# 1. b > 0, s==1: this indicates a generate-only batch of tokens.
|
||||
# 2. b==1, s > 0: this indicates a mixed context+generate phase. The actual number of sequences
|
||||
# and number of tokens per sequence are encoded in seq_len and seq_start.
|
||||
# Assuming that context seq_len always > 0.
|
||||
b, s, d = q.shape
|
||||
head_dim = k_cache.shape[-1]
|
||||
|
||||
# reshapes with num_heads and head_dim
|
||||
if s == 1:
|
||||
bs_view = (b, s)
|
||||
else:
|
||||
bs_view = (b * s,)
|
||||
q = q.view(*bs_view, q.shape[2] // head_dim, head_dim)
|
||||
k = k.view(*bs_view, k.shape[2] // head_dim, head_dim)
|
||||
v = v.view(*bs_view, v.shape[2] // head_dim, head_dim)
|
||||
|
||||
# rope embedding for generate-only or mixed
|
||||
if freqs_cis is not None:
|
||||
if s == 1:
|
||||
rope_args = (freqs_cis, input_pos, "bsnd")
|
||||
fn_rope = torch.ops.auto_deploy.triton_rope_with_input_pos
|
||||
else:
|
||||
rope_args = (freqs_cis, input_pos, seq_len, seq_start)
|
||||
fn_rope = torch.ops.auto_deploy.triton_rope_on_flattened_inputs
|
||||
q = fn_rope(q, *rope_args)
|
||||
k = fn_rope(k, *rope_args)
|
||||
|
||||
# run attention
|
||||
y = torch.empty_like(q)
|
||||
if s == 1:
|
||||
# generate-only phase
|
||||
_paged_generate_mha(
|
||||
q, k, v, page_table, k_cache, v_cache, cache_loc, input_pos, y, max_seq_len
|
||||
)
|
||||
else:
|
||||
# mixed context + generate phase
|
||||
_paged_context_mha(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
page_table,
|
||||
k_cache,
|
||||
v_cache,
|
||||
seq_len,
|
||||
seq_start,
|
||||
y,
|
||||
max_seq_len,
|
||||
)
|
||||
|
||||
return y.view(b, s, d) # [b,s,n*h_d]
|
||||
|
||||
|
||||
@fused_mha_with_paged_cache.register_fake
|
||||
def fused_mha_with_paged_cache_fake(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
page_table: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
freqs_cis: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(q.contiguous())
|
||||
|
||||
|
||||
def _generate_mha_rope_fusion(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
cache_locs: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
):
|
||||
b, (n_heads, d_head) = q.shape[0], q.shape[-2:]
|
||||
max_seq_len, n_kv_heads = k_cache.shape[1:3]
|
||||
device = q.device
|
||||
|
||||
SEQ_BLOCK_SIZE = 64
|
||||
num_blocks = (max_seq_len + SEQ_BLOCK_SIZE - 1) // SEQ_BLOCK_SIZE
|
||||
stage1_output_values = torch.empty(
|
||||
b, n_heads, num_blocks, d_head, device=device, dtype=torch.float32
|
||||
)
|
||||
stage1_output_logsumexp = torch.empty(
|
||||
b, n_heads, num_blocks, device=device, dtype=torch.float32
|
||||
) - float("inf")
|
||||
q_rope = torch.empty_like(q)
|
||||
HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads))
|
||||
|
||||
(
|
||||
update_kv_cache_rope_fusion[(b, n_kv_heads, 1)](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
None,
|
||||
q_rope,
|
||||
k_cache,
|
||||
v_cache,
|
||||
input_pos,
|
||||
cache_locs,
|
||||
freqs_cis,
|
||||
max_seq_len,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
d_head,
|
||||
1,
|
||||
HEAD_BLOCK_SIZE,
|
||||
GENERATE_ONLY=True,
|
||||
),
|
||||
)
|
||||
|
||||
HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads))
|
||||
scale = 1.0 / math.sqrt(d_head)
|
||||
gqa_attention_kv_stage1[
|
||||
(
|
||||
b,
|
||||
n_kv_heads,
|
||||
num_blocks,
|
||||
)
|
||||
](
|
||||
q_rope,
|
||||
k_cache,
|
||||
v_cache,
|
||||
cache_locs,
|
||||
input_pos,
|
||||
stage1_output_values,
|
||||
stage1_output_logsumexp,
|
||||
num_blocks,
|
||||
scale,
|
||||
max_seq_len,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
d_head,
|
||||
d_head,
|
||||
SEQ_BLOCK_SIZE,
|
||||
HEAD_BLOCK_SIZE,
|
||||
-1,
|
||||
)
|
||||
attention_kv_stage2[(b, n_heads, 1)](
|
||||
stage1_output_values,
|
||||
stage1_output_logsumexp,
|
||||
out,
|
||||
input_pos,
|
||||
num_blocks,
|
||||
n_heads,
|
||||
d_head,
|
||||
SEQ_BLOCK_SIZE,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def _flattened_context_mha_rope_fusion(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
# NOTE: s_total == sum(seq_len)
|
||||
s_total, n_heads, d_head = q.shape
|
||||
max_cache_seq_len, n_kv_heads = k_cache.shape[1:3]
|
||||
BATCH_SIZE: int = len(input_pos)
|
||||
SEQ_BLOCK = 32
|
||||
q_rope = torch.empty_like(q)
|
||||
HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads))
|
||||
(
|
||||
update_kv_cache_rope_fusion[
|
||||
(BATCH_SIZE, n_kv_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
|
||||
](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
seq_len,
|
||||
seq_start,
|
||||
q_rope,
|
||||
k_cache,
|
||||
v_cache,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
freqs_cis,
|
||||
max_cache_seq_len,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
d_head,
|
||||
32,
|
||||
HEAD_BLOCK_SIZE,
|
||||
GENERATE_ONLY=False,
|
||||
),
|
||||
)
|
||||
# TODO: use input_pos to get the correct cache locations
|
||||
softmax_scale = 1.0 / math.sqrt(d_head)
|
||||
grid = (BATCH_SIZE, n_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
|
||||
context_attention_kv_flattened[grid](
|
||||
q_rope,
|
||||
seq_len,
|
||||
seq_start,
|
||||
k_cache,
|
||||
v_cache,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
out,
|
||||
softmax_scale,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
d_head,
|
||||
d_head,
|
||||
SEQ_BLOCK,
|
||||
max_cache_seq_len,
|
||||
-1,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
@torch.library.custom_op(
|
||||
"auto_deploy::triton_attention_fused_flattened_mha_with_cache_rope_fusion", mutates_args=()
|
||||
)
|
||||
def fused_flattened_mha_with_cache_rope_fusion(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
freqs_cis: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Flattened & fused MHA with cache that takes raw input from q, k, v GEMMs.
|
||||
|
||||
Fuse k rope in update_kv_cache and q rope in attention.
|
||||
NOTE: this op can also handle seq_len==0, which might be useful for CUDAGRAPH.
|
||||
"""
|
||||
# this function only handle requests with rope embadding.
|
||||
if freqs_cis is None:
|
||||
# NOTE: pass the metadata by keyword since the flattened op expects the
# (seq_len, input_pos, cache_loc, seq_start) ordering, and hand it an empty
# freqs_cis tensor so its RoPE branch is skipped.
return fused_flattened_mha_with_cache(
q,
k,
v,
seq_len=seq_len,
input_pos=input_pos,
cache_loc=cache_loc,
seq_start=seq_start,
k_cache=k_cache,
v_cache=v_cache,
freqs_cis=q.new_empty(0),
)
|
||||
|
||||
# b, s info
|
||||
# NOTE: b, s are just the shapes of the input tensor q; not necessarily the number of sequences.
# Generally speaking, we expect one of two cases here:
# 1. b > 0, s==1: this indicates a generate-only batch of tokens.
# 2. b==1, s > 0: this indicates a mixed context+generate phase. The actual number of sequences
# and number of tokens per sequence are encoded in seq_len and seq_start.
b, s, d = q.shape
|
||||
head_dim = k_cache.shape[-1]
|
||||
|
||||
# reshapes with num_heads and head_dim
|
||||
if s == 1:
|
||||
bs_view = (b, s)
|
||||
else:
|
||||
bs_view = (b * s,)
|
||||
q = q.view(*bs_view, q.shape[2] // head_dim, head_dim)
|
||||
k = k.view(*bs_view, k.shape[2] // head_dim, head_dim)
|
||||
v = v.view(*bs_view, v.shape[2] // head_dim, head_dim)
|
||||
|
||||
# run attention
|
||||
y = torch.empty_like(q)
|
||||
if s == 1:
|
||||
# generate-only phase
|
||||
_generate_mha_rope_fusion(q, k, v, freqs_cis, k_cache, v_cache, cache_loc, input_pos, y)
|
||||
else:
|
||||
# mixed context + generate phase
|
||||
_flattened_context_mha_rope_fusion(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
freqs_cis,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
k_cache,
|
||||
v_cache,
|
||||
seq_len,
|
||||
seq_start,
|
||||
y,
|
||||
)
|
||||
|
||||
return y.view(b, s, d) # [b,s,n*h_d]
|
||||
|
||||
|
||||
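# --- Illustrative sketch (not part of the original op): a hypothetical generate-only call of
# the rope-fusion variant. freqs_cis follows the [max_seq_len, head_dim // 2, 2] layout used
# by the triton rope kernels; all values are made up and a CUDA device with triton is assumed,
# so treat this as a shape reference rather than a test.
def _example_rope_fusion_generate_call():
    b, n_heads, head_dim, max_seq, max_batch = 2, 8, 64, 128, 4
    dev, dtype = "cuda", torch.float16
    q = torch.randn(b, 1, n_heads * head_dim, device=dev, dtype=dtype)  # s == 1 -> generate
    k, v = torch.randn_like(q), torch.randn_like(q)
    input_pos = torch.zeros(b, dtype=torch.int32, device=dev)
    cache_loc = torch.arange(b, dtype=torch.int32, device=dev)
    seq_len = torch.ones(b, dtype=torch.int32, device=dev)  # unused on the s == 1 path
    seq_start = torch.zeros(b, dtype=torch.int32, device=dev)
    k_cache = torch.zeros(max_batch, max_seq, n_heads, head_dim, device=dev, dtype=dtype)
    v_cache = torch.zeros_like(k_cache)
    freqs_cis = torch.randn(max_seq, head_dim // 2, 2, device=dev, dtype=torch.float32)
    return torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache_rope_fusion(
        q, k, v, input_pos, cache_loc, seq_len, seq_start, k_cache, v_cache, freqs_cis
    )  # -> [b, 1, n_heads * head_dim]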
@fused_flattened_mha_with_cache_rope_fusion.register_fake
|
||||
def fused_flattened_mha_with_cache_rope_fusion_fake(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
):
|
||||
return torch.empty_like(q.contiguous())
|
||||
|
||||
|
||||
def _context_mha(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
):
|
||||
b, s, n_heads, q_d_head = q.shape
|
||||
max_seq_len, n_kv_heads = k_cache.shape[1:3]
|
||||
v_d_head = v.shape[-1]
|
||||
|
||||
SEQ_BLOCK = 32
|
||||
softmax_scale = 1.0 / math.sqrt(q_d_head)
|
||||
grid = (b, n_heads, (s + SEQ_BLOCK - 1) // SEQ_BLOCK)
|
||||
context_attention_kv[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
s,
|
||||
out,
|
||||
softmax_scale,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
q_d_head,
|
||||
v_d_head,
|
||||
SEQ_BLOCK,
|
||||
max_seq_len,
|
||||
num_stages=2,
|
||||
)
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::triton_attention_fused_mha_with_cache", mutates_args=())
|
||||
def fused_mha_with_cache(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
freqs_cis: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Fused MHA with cache that takes raw input from q, k, v GEMMs."""
|
||||
# b, s info
|
||||
b, s = q.shape[:2]
|
||||
head_dim = k_cache.shape[-1]
|
||||
|
||||
# reshapes with num_heads and head_dim
|
||||
q = q.view(b, s, -1, head_dim)
|
||||
k = k.view(b, s, -1, head_dim)
|
||||
v = v.view(b, s, -1, head_dim)
|
||||
|
||||
# rope embedding
|
||||
if freqs_cis is not None:
|
||||
q = torch.ops.auto_deploy.triton_rope_with_input_pos(q, freqs_cis, input_pos, "bsnd")
|
||||
k = torch.ops.auto_deploy.triton_rope_with_input_pos(k, freqs_cis, input_pos, "bsnd")
|
||||
|
||||
# attention (assumed layout is bsnd)
|
||||
y = torch.empty_like(q)
|
||||
scale = 1.0 / math.sqrt(head_dim)
|
||||
if s > 1:
|
||||
# context phase
|
||||
_context_mha(q, k, v, k_cache, v_cache, y)
|
||||
else:
|
||||
# generate phase
|
||||
cache_locs = torch.arange(0, b, device=q.device, dtype=torch.int32)
|
||||
_generate_mha(q, k, v, k_cache, v_cache, cache_locs, input_pos, scale, y)
|
||||
|
||||
return y.view(b, s, -1) # [b,s,n*h_d]
|
||||
|
||||
|
||||
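# --- Illustrative sketch (not part of the original op): a hypothetical generate-phase call.
# Shapes and values are made up; running it requires a CUDA device with triton available.
def _example_fused_mha_with_cache_call():
    b, n_heads, head_dim, max_seq = 2, 8, 64, 128
    dev, dtype = "cuda", torch.float16
    q = torch.randn(b, 1, n_heads * head_dim, device=dev, dtype=dtype)  # s == 1 -> generate
    k, v = torch.randn_like(q), torch.randn_like(q)
    input_pos = torch.zeros(b, dtype=torch.int32, device=dev)  # write/read at position 0
    k_cache = torch.zeros(b, max_seq, n_heads, head_dim, device=dev, dtype=dtype)
    v_cache = torch.zeros_like(k_cache)
    # freqs_cis=None skips the RoPE branch above
    return torch.ops.auto_deploy.triton_attention_fused_mha_with_cache(
        q, k, v, input_pos, k_cache, v_cache, None
    )  # -> [b, 1, n_heads * head_dim]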
@fused_mha_with_cache.register_fake
|
||||
def fused_mha_fake(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
):
|
||||
return torch.empty_like(q.contiguous())
|
||||
|
||||
|
||||
@torch.library.custom_op(
|
||||
"auto_deploy::triton_attention_fused_flattened_mha_with_cache", mutates_args=()
|
||||
)
|
||||
def fused_flattened_mha_with_cache(
|
||||
# Q, K, V
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
# METADATA
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
# CACHES
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
# BUFFERS
|
||||
freqs_cis: torch.Tensor,
|
||||
# CONSTANTS
|
||||
# <none>
|
||||
) -> torch.Tensor:
|
||||
"""Flattened & fused MHA with cache that takes raw input from q, k, v GEMMs.
|
||||
|
||||
NOTE: this op can also handle seq_len==0, which might be useful for CUDAGRAPH.
|
||||
"""
|
||||
# b, s info
|
||||
# NOTE: b, s are just the shapes of the input tensor q; not necessarily the number of sequences.
# Generally speaking, we expect one of two cases here:
# 1. b > 0, s==1: this indicates a generate-only batch of tokens.
# 2. b==1, s > 0: this indicates a mixed context+generate phase. The actual number of sequences
# and number of tokens per sequence are encoded in seq_len and seq_start.
head_dim = k_cache.shape[-1]
|
||||
b, s, d = q.shape
|
||||
|
||||
# reshapes with num_heads and head_dim
|
||||
if s == 1:
|
||||
bs_view = (b, s)
|
||||
else:
|
||||
bs_view = (b * s,)
|
||||
q = q.view(*bs_view, q.shape[2] // head_dim, head_dim)
|
||||
k = k.view(*bs_view, k.shape[2] // head_dim, head_dim)
|
||||
v = v.view(*bs_view, v.shape[2] // head_dim, head_dim)
|
||||
|
||||
# rope embedding for generate-only or mixed
|
||||
if freqs_cis.numel() > 0:
|
||||
if s == 1:
|
||||
rope_args = (freqs_cis, input_pos, "bsnd")
|
||||
fn_rope = torch.ops.auto_deploy.triton_rope_with_input_pos
|
||||
else:
|
||||
rope_args = (freqs_cis, input_pos, seq_len, seq_start)
|
||||
fn_rope = torch.ops.auto_deploy.triton_rope_on_flattened_inputs
|
||||
q = fn_rope(q, *rope_args)
|
||||
k = fn_rope(k, *rope_args)
|
||||
|
||||
# run attention
|
||||
y = torch.empty_like(q)
|
||||
if s == 1:
|
||||
# generate-only phase
|
||||
_generate_mha(q, k, v, k_cache, v_cache, cache_loc, input_pos, y)
|
||||
else:
|
||||
# mixed context + generate phase
|
||||
_flattened_context_mha(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
k_cache,
|
||||
v_cache,
|
||||
seq_len,
|
||||
seq_start,
|
||||
y,
|
||||
)
|
||||
|
||||
return y.view(b, s, d) # [b,s,n*h_d]
|
||||
|
||||
|
||||
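# --- Illustrative sketch (not part of the original op): a hypothetical mixed context batch
# with two sequences of length 3 and 5 flattened into b == 1, s == 8. Values are made up;
# running it requires a CUDA device with triton available.
def _example_flattened_mixed_batch_call():
    n_heads, head_dim, max_seq, max_batch = 8, 64, 128, 4
    dev, dtype = "cuda", torch.float16
    seq_len = torch.tensor([3, 5], dtype=torch.int32, device=dev)
    seq_start = torch.tensor([0, 3], dtype=torch.int32, device=dev)
    input_pos = torch.zeros(2, dtype=torch.int32, device=dev)  # no previously cached tokens
    cache_loc = torch.tensor([0, 1], dtype=torch.int32, device=dev)
    q = torch.randn(1, 8, n_heads * head_dim, device=dev, dtype=dtype)  # 8 == 3 + 5 tokens
    k, v = torch.randn_like(q), torch.randn_like(q)
    k_cache = torch.zeros(max_batch, max_seq, n_heads, head_dim, device=dev, dtype=dtype)
    v_cache = torch.zeros_like(k_cache)
    freqs_cis = q.new_empty(0)  # numel() == 0 -> RoPE is skipped
    return torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache(
        q, k, v, seq_len, input_pos, cache_loc, seq_start, k_cache, v_cache, freqs_cis
    )  # -> [1, 8, n_heads * head_dim]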
@fused_flattened_mha_with_cache.register_fake
|
||||
def fused_flattened_mha_fake(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
):
|
||||
return torch.empty_like(q.contiguous())
|
||||
@ -11,7 +11,7 @@ and operates on a purely functional paradigm that is compatible with the torch c
|
||||
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union, final
|
||||
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch._ops import OpOverloadPacket
|
||||
@ -395,6 +395,9 @@ class SequenceInfo:
|
||||
# see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
|
||||
self.max_num_tokens = max_num_tokens or (max_seq_len + 1) * max_batch_size
|
||||
|
||||
# will store num_blocks later...
|
||||
self._num_blocks = None
|
||||
|
||||
# TODO (lucaslie): can we remove this eventually from this i/f?
|
||||
self.vocab_size_padded = vocab_size_padded
|
||||
|
||||
@ -569,6 +572,11 @@ class SequenceInfo:
|
||||
"""Return the page assignments for each sequence."""
|
||||
return self._get_page_assignments(self.cache_loc, self.pages_per_seq)
|
||||
|
||||
@property
|
||||
def num_blocks(self) -> int:
|
||||
assert self._num_blocks is not None, "num_blocks not set yet"
|
||||
return self._num_blocks
|
||||
|
||||
def estimate_cache_tokens_per_forward(self) -> int:
|
||||
"""Estimate the max number of tokens that will be cached for a forward pass.
|
||||
|
||||
@ -581,6 +589,10 @@ class SequenceInfo:
|
||||
|
||||
def estimate_cache_loc_capacity(self, num_blocks: int) -> None:
|
||||
"""Estimate needed capacity of cache_loc based on available blocks and resize."""
|
||||
# set num_blocks
|
||||
self._num_blocks = num_blocks
|
||||
|
||||
# get current capacity
|
||||
cache_loc_capacity = self._input_buffer.get_capacity("cache_loc")
|
||||
|
||||
# cache_loc requires some special treatment due to block reuse. Note that the constraint for
|
||||
@ -1011,42 +1023,135 @@ class ResourceHandler(ABC):
|
||||
performs on the resources providing an abstract handle.
|
||||
"""
|
||||
|
||||
@property
|
||||
def is_paged(self) -> bool:
|
||||
"""Whether the resource is paged.
|
||||
|
||||
If the resource is paged, it will participate in the resize computation of the caches and
|
||||
needs to implement the _get_bytes_per_token method.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def bytes_per_token(self) -> int:
|
||||
"""The size of the resource per token."""
|
||||
if self.is_paged:
|
||||
return self._get_bytes_per_token()
|
||||
else:
|
||||
raise NotImplementedError(f"Resource {self.__class__.__name__} is not paged.")
|
||||
|
||||
def _get_bytes_per_token(self) -> int:
|
||||
"""The size of the resource per token."""
|
||||
raise NotImplementedError(
|
||||
f"Resource {self.__class__.__name__} needs to implement _get_bytes_per_token."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
|
||||
"""Initialize the resource for the given sequence info."""
|
||||
|
||||
|
||||
class ManagedResourceHandler(ResourceHandler):
|
||||
"""An abstract interface to handle a resource that is managed by the cache manager."""
|
||||
class KVPagedResourceHandler(ResourceHandler):
|
||||
"""Handler for paged KV cache resources.
|
||||
|
||||
@final
|
||||
def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
|
||||
"""Allocate the resource for the given sequence info."""
|
||||
raise NotImplementedError("Managed resources should not be allocated directly!")
|
||||
This handler indicates the resource should be managed by the standard KVCacheManager.
|
||||
|
||||
|
||||
class PagedResourceHandler(ManagedResourceHandler):
|
||||
"""An abstract interface to handle a paged resource.
|
||||
|
||||
The PagedResourceHandler can be used to handle resources that support paging such as kv-caches.
|
||||
Args:
|
||||
num_kv_heads: Number of key-value heads.
|
||||
head_dim: Dimension of each head.
|
||||
dtype: The dtype of the KV cache.
|
||||
kv_factor: The factor of the KV cache. Default is 2 for combined k/v cache.
|
||||
kv_layout: Memory layout for the KV cache. Either "HND" (head-num-dim) or "NHD" (num-head-dim).
|
||||
Default is "HND" which is the standard layout for flashinfer.
|
||||
"""
|
||||
|
||||
def __init__(self, *token_shape: int, dtype: torch.dtype) -> None:
|
||||
"""Initialize the PagedResourceHandler.
|
||||
@property
|
||||
def is_paged(self) -> bool:
|
||||
"""Whether the resource is paged."""
|
||||
return True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
dtype: torch.dtype,
|
||||
kv_factor: int = 2,
|
||||
kv_layout: Literal["HND", "NHD"] = "HND",
|
||||
) -> None:
|
||||
"""Initialize the KVPagedResourceHandler.
|
||||
|
||||
Args:
|
||||
page_shape: The shape of a single page of the resource.
|
||||
dtype: The dtype of the resource.
|
||||
num_kv_heads: Number of key-value heads.
|
||||
head_dim: Dimension of each head.
|
||||
dtype: The dtype of the KV cache.
|
||||
kv_factor: The factor of the KV cache. Default is 2.
|
||||
kv_layout: Memory layout - "HND" or "NHD". Default is "HND".
|
||||
"""
|
||||
self.token_shape = token_shape
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = head_dim
|
||||
self.dtype = dtype
|
||||
self.kv_factor = kv_factor
|
||||
assert kv_factor in [1, 2], f"Invalid kv_factor: {kv_factor}"
|
||||
self.kv_layout = kv_layout
|
||||
|
||||
def __eq__(self, other: Optional[ResourceHandler]) -> bool:
|
||||
"""Check compatibility for KVCacheManager (head_dim and dtype must match)."""
|
||||
if type(other) is not type(self):
|
||||
return False
|
||||
return (
|
||||
self.head_dim == other.head_dim
|
||||
and self.dtype == other.dtype
|
||||
and self.kv_factor == other.kv_factor
|
||||
and self.kv_layout == other.kv_layout
|
||||
)
|
||||
|
||||
def _get_bytes_per_token(self) -> int:
|
||||
"""The size of the resource per token in bytes."""
|
||||
return self.num_kv_heads * self.kv_factor * self.head_dim * self.dtype.itemsize
|
||||
|
||||
def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
|
||||
"""Allocate paged resource locally when not managed by KVCacheManager.
|
||||
|
||||
Args:
|
||||
sequence_info: Sequence information containing device info.
|
||||
|
||||
Returns:
|
||||
Allocated tensor with shape depending on kv_layout:
|
||||
- NHD: [num_blocks, 2, tokens_per_block, num_kv_heads, head_dim]
|
||||
- HND: [num_blocks, 2, num_kv_heads, tokens_per_block, head_dim]
|
||||
"""
|
||||
if self.kv_layout == "HND":
|
||||
return torch.empty(
|
||||
sequence_info.num_blocks,
|
||||
self.kv_factor,
|
||||
self.num_kv_heads,
|
||||
sequence_info.tokens_per_block,
|
||||
self.head_dim,
|
||||
device=sequence_info.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
elif self.kv_layout == "NHD":
|
||||
return torch.empty(
|
||||
sequence_info.num_blocks,
|
||||
sequence_info.tokens_per_block,
|
||||
self.kv_factor,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
device=sequence_info.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {self.kv_layout}")
|
||||
|
||||
|
||||
class StateResourceHandler(ManagedResourceHandler):
|
||||
class StateResourceHandler(ResourceHandler):
|
||||
"""Handler for per-sequence state resources (e.g., Mamba SSM/conv states).
|
||||
|
||||
These resources have shape [max_batch_size, *state_shape] and are
|
||||
managed by MambaHybridCacheManager via byte-level pooling.
|
||||
These resources have shape [max_batch_size, *state_shape] and can be either:
|
||||
- Managed by MambaHybridCacheManager (for typed subclasses SSMResourceHandler, CausalConvResourceHandler)
|
||||
- Allocated locally via allocate() (for generic StateResourceHandler or when constraints don't hold)
|
||||
|
||||
Subclasses should define state_shape as a property that returns the appropriate shape.
|
||||
"""
|
||||
|
||||
def __init__(self, *state_shape: int, dtype: torch.dtype) -> None:
|
||||
@ -1056,9 +1161,96 @@ class StateResourceHandler(ManagedResourceHandler):
|
||||
state_shape: The shape of a single state resource.
|
||||
dtype: The dtype of the state resource.
|
||||
"""
|
||||
self.state_shape = state_shape
|
||||
self._state_shape = state_shape
|
||||
self.dtype = dtype
|
||||
|
||||
@property
|
||||
def state_shape(self) -> Tuple[int, ...]:
|
||||
"""Return the state shape. Subclasses may override this as a property."""
|
||||
return self._state_shape
|
||||
|
||||
def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
|
||||
"""Allocate state resource locally (fallback when not managed by cache manager)."""
|
||||
return torch.empty(
|
||||
sequence_info.max_num_state_slots,
|
||||
*self.state_shape,
|
||||
device=sequence_info.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def __eq__(self, other: Optional[ResourceHandler]) -> bool:
|
||||
"""Check compatibility for MambaHybridCacheManager state resources."""
|
||||
if type(other) is not type(self):
|
||||
return False
|
||||
return self.state_shape == other.state_shape and self.dtype == other.dtype
|
||||
|
||||
|
||||
class SSMResourceHandler(StateResourceHandler):
|
||||
"""Handler for SSM state resources that maps directly to MambaCacheManager's ssm_states buffer.
|
||||
|
||||
These resources have shape [max_batch_size, num_heads, head_dim, d_state] and are
|
||||
managed by MambaHybridCacheManager via the ssm_states buffer when compatible.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
d_state: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
"""Initialize the SSMResourceHandler.
|
||||
|
||||
Args:
|
||||
num_heads: Number of attention heads.
|
||||
head_dim: Dimension per head.
|
||||
d_state: SSM state size.
|
||||
dtype: Data type for the state.
|
||||
"""
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.d_state = d_state
|
||||
# Call parent with dtype only (state_shape comes from property)
|
||||
super().__init__(dtype=dtype)
|
||||
|
||||
@property
|
||||
def state_shape(self) -> Tuple[int, int, int]:
|
||||
"""Return the SSM state shape: (num_heads, head_dim, d_state)."""
|
||||
return (self.num_heads, self.head_dim, self.d_state)
|
||||
|
||||
|
||||
class CausalConvResourceHandler(StateResourceHandler):
|
||||
"""Handler for causal conv state resources that maps to MambaCacheManager's conv_states buffer.
|
||||
|
||||
These resources have shape [max_batch_size, conv_dim, d_conv - 1] and are
|
||||
managed by MambaHybridCacheManager via the conv_states buffer when compatible.
|
||||
|
||||
Note: d_conv is the kernel size, and (d_conv - 1) is the state size stored in the cache.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conv_dim: int,
|
||||
d_conv: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
"""Initialize the CausalConvResourceHandler.
|
||||
|
||||
Args:
|
||||
conv_dim: Convolution dimension (typically in_channels).
|
||||
d_conv: Kernel size. The cache stores d_conv - 1 elements.
|
||||
dtype: Data type for the state.
|
||||
"""
|
||||
self.conv_dim = conv_dim
|
||||
self.d_conv = d_conv # kernel_size
|
||||
# Call parent with dtype only (state_shape comes from property)
|
||||
super().__init__(dtype=dtype)
|
||||
|
||||
@property
|
||||
def state_shape(self) -> Tuple[int, int]:
|
||||
"""Return the conv state shape: (conv_dim, d_conv - 1)."""
|
||||
return (self.conv_dim, self.d_conv - 1)
|
||||
|
||||
|
||||
class UnpagedResourceHandler(ResourceHandler):
|
||||
"""Handler for per-token unpaged resources (e.g., unpaged KV caches).
|
||||
|
||||
@ -17,8 +17,8 @@ from .attention_interface import (
|
||||
AttentionLayout,
|
||||
AttentionRegistry,
|
||||
Constant,
|
||||
KVPagedResourceHandler,
|
||||
MHACallable,
|
||||
PagedResourceHandler,
|
||||
PrepareMetadataCallable,
|
||||
PrepareMetadataHostCallable,
|
||||
ResourceHandlerDict,
|
||||
@ -56,9 +56,7 @@ class _FlashInferPlanner:
|
||||
]
|
||||
plan_params_prefill: Optional[PlanParams]
|
||||
plan_params_decode: Optional[PlanParams]
|
||||
kv_layout: Literal["NHD", "HND"] = (
|
||||
"NHD" # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/10966
|
||||
)
|
||||
kv_layout: Literal["NHD", "HND"] = "HND"
|
||||
|
||||
def __init__(self):
|
||||
self.workspace_buffer = None
|
||||
@ -320,16 +318,16 @@ def flashinfer_mha_with_cache(
|
||||
# EXTRA METADATA
|
||||
flashinfer_batch_indices: torch.Tensor,
|
||||
flashinfer_positions: torch.Tensor,
|
||||
# CACHES
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
# CACHES - combined KV cache with shape [num_blocks, 2, num_kv_heads, tokens_per_block, head_dim]
|
||||
kv_cache: torch.Tensor,
|
||||
# CONSTANTS
|
||||
scale: Optional[float],
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
) -> torch.Tensor:
|
||||
# reshape to standard [b*s, n_heads, head_dim] layout
|
||||
head_dim = k_cache.shape[-1]
|
||||
# kv_cache shape: [num_blocks, 2, num_kv_heads, tokens_per_block, head_dim] (HND layout)
|
||||
head_dim = kv_cache.shape[-1]
|
||||
page_size = kv_cache.shape[3] # tokens_per_block
|
||||
q_shape_og = q.shape
|
||||
b, s = q_shape_og[:2]
|
||||
|
||||
@ -348,7 +346,7 @@ def flashinfer_mha_with_cache(
|
||||
# Assuming k_scale = v_scale = 1.0
|
||||
k_scale, v_scale = 1.0, 1.0
|
||||
# k = (k / k_scale).to(torch.float8_e4m3fn) if k_scale != 1.0, same for v
|
||||
if k_cache.dtype == torch.float8_e4m3fn:
|
||||
if kv_cache.dtype == torch.float8_e4m3fn:
|
||||
k = k.to(torch.float8_e4m3fn)
|
||||
v = v.to(torch.float8_e4m3fn)
|
||||
|
||||
@ -357,10 +355,11 @@ def flashinfer_mha_with_cache(
|
||||
append_value=v,
|
||||
batch_indices=flashinfer_batch_indices,
|
||||
positions=flashinfer_positions,
|
||||
paged_kv_cache=(k_cache, v_cache),
|
||||
paged_kv_cache=kv_cache,
|
||||
kv_indices=cache_loc,
|
||||
kv_indptr=cu_num_pages[: num_seq + 1],
|
||||
kv_last_page_len=last_page_len[:num_seq],
|
||||
kv_layout=_GlobalFlashInferPlanner.kv_layout,
|
||||
)
|
||||
|
||||
# check if we need to re-combine outputs
|
||||
@ -378,9 +377,9 @@ def flashinfer_mha_with_cache(
|
||||
n_kv_heads=n_kv_heads,
|
||||
head_dim=head_dim,
|
||||
num_seq=num_prefill,
|
||||
page_size=k_cache.shape[1],
|
||||
page_size=page_size,
|
||||
q_dtype=q_prefill.dtype,
|
||||
kv_dtype=k_cache.dtype,
|
||||
kv_dtype=kv_cache.dtype,
|
||||
sm_scale=scale,
|
||||
)
|
||||
|
||||
@ -395,7 +394,7 @@ def flashinfer_mha_with_cache(
|
||||
|
||||
y_prefill = wrapper_prefill.run(
|
||||
q_prefill,
|
||||
(k_cache, v_cache),
|
||||
kv_cache,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
enable_pdl=get_env_enable_pdl(),
|
||||
@ -413,9 +412,9 @@ def flashinfer_mha_with_cache(
|
||||
n_kv_heads=n_kv_heads,
|
||||
head_dim=head_dim,
|
||||
num_seq=num_decode,
|
||||
page_size=k_cache.shape[1],
|
||||
page_size=page_size,
|
||||
q_dtype=q_decode.dtype,
|
||||
kv_dtype=k_cache.dtype,
|
||||
kv_dtype=kv_cache.dtype,
|
||||
sm_scale=scale,
|
||||
)
|
||||
|
||||
@ -429,7 +428,7 @@ def flashinfer_mha_with_cache(
|
||||
|
||||
y_decode = wrapper_decode.run(
|
||||
q_decode,
|
||||
(k_cache, v_cache),
|
||||
kv_cache,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
enable_pdl=get_env_enable_pdl(),
|
||||
@ -460,9 +459,8 @@ def flashinfer_mha_with_cache_fake(
|
||||
# EXTRA METADATA
|
||||
flashinfer_batch_indices: torch.Tensor,
|
||||
flashinfer_positions: torch.Tensor,
|
||||
# CACHES
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
# CACHES - combined KV cache
|
||||
kv_cache: torch.Tensor,
|
||||
# CONSTANTS
|
||||
scale: Optional[float],
|
||||
k_scale: float,
|
||||
@ -525,16 +523,13 @@ class FlashInferAttention(AttentionDescriptor):
|
||||
head_dim = k_fake.shape[3]
|
||||
|
||||
return {
|
||||
"k_cache": PagedResourceHandler(
|
||||
"kv_cache": KVPagedResourceHandler(
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype),
|
||||
),
|
||||
"v_cache": PagedResourceHandler(
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype),
|
||||
),
|
||||
kv_factor=2,
|
||||
kv_layout=_GlobalFlashInferPlanner.kv_layout,
|
||||
)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -39,10 +39,10 @@ from ..attention_interface import (
|
||||
AttentionDescriptor,
|
||||
AttentionLayout,
|
||||
AttentionRegistry,
|
||||
CausalConvResourceHandler,
|
||||
Constant,
|
||||
MHACallable,
|
||||
ResourceHandlerDict,
|
||||
StateResourceHandler,
|
||||
)
|
||||
|
||||
|
||||
@ -197,10 +197,12 @@ class CudaBackendCausalConv(AttentionDescriptor):
|
||||
in_channels = inp_fake.shape[-1]
|
||||
kernel_size = w_fake.shape[-1]
|
||||
|
||||
conv_state_handler = StateResourceHandler(
|
||||
in_channels,
|
||||
max(1, kernel_size - 1),
|
||||
# NOTE: not configurable at the moment, using auto to match the input dtype
|
||||
# NOTE: cuda backend stores kernel_size - 1 elements in state.
|
||||
# CausalConvResourceHandler.state_shape = (conv_dim, d_conv - 1), so d_conv = kernel_size.
|
||||
# Ensure d_conv >= 1 (state_shape[-1] >= 0).
|
||||
conv_state_handler = CausalConvResourceHandler(
|
||||
conv_dim=in_channels,
|
||||
d_conv=max(1, kernel_size), # state_shape[-1] = d_conv - 1 = kernel_size - 1
|
||||
dtype=cls.resolve_cache_dtype("auto", inp_fake.dtype),
|
||||
)
|
||||
return {"conv_state_cache": conv_state_handler}
|
||||
|
||||
@ -4,13 +4,7 @@ import torch
|
||||
from torch.fx import Node
|
||||
|
||||
from .....llmapi.llm_args import KvCacheConfig
|
||||
from ..attention_interface import (
|
||||
AttentionRegistry,
|
||||
MHACallable,
|
||||
ResourceHandler,
|
||||
ResourceHandlerDict,
|
||||
SequenceInfo,
|
||||
)
|
||||
from ..attention_interface import AttentionRegistry, MHACallable, ResourceHandlerDict
|
||||
from .mamba_backend_common import (
|
||||
BaseBackendSSM,
|
||||
_flatten_ssm_inputs,
|
||||
@ -177,29 +171,6 @@ def _flashinfer_cached_ssm_fake(
|
||||
FLASHINFER_SUPPORTED_HEAD_DIMS = [64, 128]
|
||||
|
||||
|
||||
class FlashInferStateResourceHandler(ResourceHandler):
|
||||
"""Handler for flashinfer SSM state resources.
|
||||
|
||||
Unlike the default StateResourceHandler which uses byte-level pooling (resulting
|
||||
in non-contiguous strided views), this handler allocates a separate contiguous
|
||||
buffer. This is required because flashinfer's selective_state_update kernel
|
||||
requires the entire state tensor to be contiguous.
|
||||
"""
|
||||
|
||||
def __init__(self, *state_shape: int, dtype: torch.dtype) -> None:
|
||||
self.state_shape = state_shape
|
||||
self.dtype = dtype
|
||||
|
||||
def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
|
||||
"""Allocate a contiguous state buffer for flashinfer."""
|
||||
return torch.empty(
|
||||
sequence_info.max_num_state_slots,
|
||||
*self.state_shape,
|
||||
device=sequence_info.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
|
||||
@AttentionRegistry.register("flashinfer_ssm")
|
||||
class FlashinferBackendSSM(BaseBackendSSM):
|
||||
@classmethod
|
||||
@ -210,37 +181,14 @@ class FlashinferBackendSSM(BaseBackendSSM):
|
||||
def get_cache_initializers(
|
||||
cls, source_attn_node: Node, cache_config: KvCacheConfig
|
||||
) -> ResourceHandlerDict:
|
||||
"""Get cache initializers using FlashInferStateResourceHandler.
|
||||
ret = super().get_cache_initializers(source_attn_node, cache_config)
|
||||
|
||||
We use a custom handler that allocates contiguous buffers directly,
|
||||
instead of the default StateResourceHandler which creates non-contiguous
|
||||
views from a shared byte buffer. This is required because flashinfer's
|
||||
selective_state_update kernel requires contiguous state tensors.
|
||||
"""
|
||||
# Shapes from fake tensors
|
||||
hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
|
||||
B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]
|
||||
|
||||
num_heads = hs_fake.shape[-2]
|
||||
head_dim = hs_fake.shape[-1]
|
||||
|
||||
# Validate head_dim is supported by flashinfer
|
||||
if head_dim not in FLASHINFER_SUPPORTED_HEAD_DIMS:
|
||||
# check head_dim is supported by flashinfer
|
||||
if ret["ssm_state_cache"].head_dim not in FLASHINFER_SUPPORTED_HEAD_DIMS:
|
||||
raise ValueError(
|
||||
f"Flashinfer SSM backend only supports head_dim in {FLASHINFER_SUPPORTED_HEAD_DIMS}, "
|
||||
f"but got head_dim={head_dim}. Consider using 'triton_ssm' backend instead."
|
||||
f"flashinfer_ssm only supports head_dim in {FLASHINFER_SUPPORTED_HEAD_DIMS}. "
|
||||
f"Got head_dim={ret['ssm_state_cache'].head_dim}. "
|
||||
"Consider using 'triton_ssm' backend instead."
|
||||
)
|
||||
|
||||
if B_fake.ndim >= 4:
|
||||
ssm_state_size = B_fake.shape[-1]
|
||||
else:
|
||||
ssm_state_size = max(1, B_fake.shape[-1])
|
||||
|
||||
# Extract ssm_state_dtype from cache_config or hs_fake
|
||||
ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype)
|
||||
|
||||
return {
|
||||
"ssm_state_cache": FlashInferStateResourceHandler(
|
||||
num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype
|
||||
)
|
||||
}
|
||||
return ret
|
||||
|
||||
@ -30,7 +30,7 @@ from ..attention_interface import (
|
||||
Constant,
|
||||
PrepareMetadataCallable,
|
||||
ResourceHandlerDict,
|
||||
StateResourceHandler,
|
||||
SSMResourceHandler,
|
||||
)
|
||||
|
||||
|
||||
@ -278,8 +278,11 @@ class BaseBackendSSM(AttentionDescriptor):
|
||||
ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype)
|
||||
|
||||
return {
|
||||
"ssm_state_cache": StateResourceHandler(
|
||||
num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype
|
||||
"ssm_state_cache": SSMResourceHandler(
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
d_state=ssm_state_size,
|
||||
dtype=ssm_state_dtype,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -22,10 +22,10 @@ from ..attention_interface import (
|
||||
AttentionDescriptor,
|
||||
AttentionLayout,
|
||||
AttentionRegistry,
|
||||
CausalConvResourceHandler,
|
||||
Constant,
|
||||
MHACallable,
|
||||
ResourceHandlerDict,
|
||||
StateResourceHandler,
|
||||
)
|
||||
|
||||
|
||||
@ -302,11 +302,12 @@ class TorchBackendCausalConv(AttentionDescriptor):
|
||||
in_channels = inp_fake.shape[-1]
|
||||
kernel_size = w_fake.shape[-1]
|
||||
|
||||
# NOTE: torch backend stores kernel_size elements in state (full conv window).
|
||||
# CausalConvResourceHandler.state_shape = (conv_dim, d_conv - 1), so d_conv = kernel_size + 1.
|
||||
return {
|
||||
"conv_state_cache": StateResourceHandler(
|
||||
in_channels,
|
||||
kernel_size,
|
||||
# NOTE: not configurable at the moment, using auto to match the input dtype
|
||||
"conv_state_cache": CausalConvResourceHandler(
|
||||
conv_dim=in_channels,
|
||||
d_conv=kernel_size + 1, # state_shape[-1] = d_conv - 1 = kernel_size
|
||||
dtype=cls.resolve_cache_dtype("auto", inp_fake.dtype),
|
||||
)
|
||||
}
|
||||
|
||||
@ -21,7 +21,7 @@ from ..attention_interface import (
|
||||
Constant,
|
||||
MHACallable,
|
||||
ResourceHandlerDict,
|
||||
StateResourceHandler,
|
||||
SSMResourceHandler,
|
||||
)
|
||||
from .torch_mamba import _torch_ssm_prefill
|
||||
|
||||
@ -312,8 +312,11 @@ class TorchBackendSSM(AttentionDescriptor):
|
||||
ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype)
|
||||
|
||||
return {
|
||||
"ssm_state_cache": StateResourceHandler(
|
||||
num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype
|
||||
"ssm_state_cache": SSMResourceHandler(
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
d_state=ssm_state_size,
|
||||
dtype=ssm_state_dtype,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Custom ops for MultiHead Latent attention."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -15,7 +16,7 @@ from .attention_interface import (
|
||||
ResourceHandlerDict,
|
||||
UnpagedResourceHandler,
|
||||
)
|
||||
from .triton_attention import _flattened_context_mha, _generate_mha
|
||||
from .triton_attention import _decode_attention, _prefill_attention
|
||||
|
||||
Constant = Union[int, float, str, None]
|
||||
|
||||
@ -140,11 +141,15 @@ def fused_flattened_mla_with_cache(
|
||||
query_states = torch.cat((q_nope, q_pe), dim=-1) # [b*s,n,d]
|
||||
key_states = torch.cat((k_nope, k_pe.expand(*bs_view, num_heads, -1)), dim=-1) # [b*s,n,d]
|
||||
|
||||
# Compute scale if not provided
|
||||
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(qk_head_dim)
|
||||
|
||||
# Compute attention
|
||||
y = torch.empty_like(value_states)
|
||||
if s == 1:
|
||||
# generate-only phase
|
||||
_generate_mha(
|
||||
# generate-only phase (decode)
|
||||
_decode_attention(
|
||||
query_states.contiguous(),
|
||||
key_states.contiguous(),
|
||||
value_states.contiguous(),
|
||||
@ -152,21 +157,23 @@ def fused_flattened_mla_with_cache(
|
||||
v_cache,
|
||||
cache_loc,
|
||||
input_pos,
|
||||
scale,
|
||||
y,
|
||||
)
|
||||
|
||||
else:
|
||||
# mixed context + generate phase
|
||||
_flattened_context_mha(
|
||||
# mixed context + generate phase (prefill)
|
||||
_prefill_attention(
|
||||
query_states.contiguous(),
|
||||
key_states.contiguous(),
|
||||
value_states.contiguous(),
|
||||
input_pos,
|
||||
cache_loc,
|
||||
k_cache,
|
||||
v_cache,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
seq_len,
|
||||
seq_start,
|
||||
scale,
|
||||
y,
|
||||
)
|
||||
|
||||
|
||||
@ -236,7 +236,7 @@ def update_kv_cache(
|
||||
v_cache: torch.Tensor,
|
||||
seq_len: torch.Tensor, # metadata
|
||||
input_pos: torch.Tensor, # metadata
|
||||
cache_loc: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@ -245,12 +245,12 @@ def update_kv_cache(
|
||||
"""
|
||||
|
||||
for idx in range(seq_len.shape[0]):
|
||||
k_cache[cache_loc[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = key_states[
|
||||
k_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = key_states[
|
||||
seq_start[idx] : seq_start[idx] + seq_len[idx], ...
|
||||
]
|
||||
v_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = value_states[
|
||||
seq_start[idx] : seq_start[idx] + seq_len[idx], ...
|
||||
]
|
||||
v_cache[cache_loc[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = (
|
||||
value_states[seq_start[idx] : seq_start[idx] + seq_len[idx], ...]
|
||||
)
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::torch_attention_fused_mla_ref", mutates_args=())
|
||||
|
||||
@ -36,7 +36,7 @@ def _torch_generate_mha(
|
||||
v: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
scale: float,
|
||||
out: torch.Tensor,
|
||||
@ -51,14 +51,14 @@ def _torch_generate_mha(
|
||||
|
||||
# Update KV cache for single token
|
||||
for i in range(b):
|
||||
cache_idx = cache_loc[i].item()
|
||||
cache_idx = slot_idx[i].item()
|
||||
pos = input_pos[i].item()
|
||||
k_cache[cache_idx, pos] = k[i, 0] # Remove sequence dim
|
||||
v_cache[cache_idx, pos] = v[i, 0] # Remove sequence dim
|
||||
|
||||
# Compute attention for each sequence using manual computation
|
||||
for i in range(b):
|
||||
cache_idx = cache_loc[i].item()
|
||||
cache_idx = slot_idx[i].item()
|
||||
pos = input_pos[i].item()
|
||||
|
||||
# Get query, key, value for this sequence
|
||||
@ -121,7 +121,7 @@ def _torch_context_mha(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
@ -134,14 +134,14 @@ def _torch_context_mha(
|
||||
) -> None:
|
||||
"""Context attention (multiple tokens, potentially multiple sequences) using existing torch functions."""
|
||||
# Update KV cache first using existing function
|
||||
update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, cache_loc, seq_start)
|
||||
update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start)
|
||||
|
||||
# Compute attention for each sequence
|
||||
attn_outputs = []
|
||||
for idx in range(seq_len.shape[0]):
|
||||
seq_len_i = seq_len[idx].item()
|
||||
input_pos_i = input_pos[idx].item()
|
||||
cache_loc_i = cache_loc[idx].item()
|
||||
slot_idx_i = slot_idx[idx].item()
|
||||
seq_start_i = seq_start[idx].item()
|
||||
|
||||
# Skip sequences with zero length
|
||||
@ -153,8 +153,8 @@ def _torch_context_mha(
|
||||
|
||||
# Get keys and values from cache
|
||||
kv_seq_len = input_pos_i + seq_len_i
|
||||
k_seq = k_cache[cache_loc_i, :kv_seq_len] # [kv_seq_len, n_kv_heads, head_dim]
|
||||
v_seq = v_cache[cache_loc_i, :kv_seq_len] # [kv_seq_len, n_kv_heads, head_dim]
|
||||
k_seq = k_cache[slot_idx_i, :kv_seq_len] # [kv_seq_len, n_kv_heads, head_dim]
|
||||
v_seq = v_cache[slot_idx_i, :kv_seq_len] # [kv_seq_len, n_kv_heads, head_dim]
|
||||
|
||||
# Manual attention computation (shared path for both softcapping and non-softcapping)
|
||||
n_heads = q_seq.shape[1]
|
||||
@ -255,7 +255,7 @@ def torch_backend_mha_with_cache(
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
# EXTRA METADATA
|
||||
#
|
||||
@ -281,7 +281,7 @@ def torch_backend_mha_with_cache(
|
||||
num_seq = num_prefill + num_decode
|
||||
seq_len = seq_len[:num_seq]
|
||||
input_pos = input_pos[:num_seq]
|
||||
cache_loc = cache_loc[:num_seq]
|
||||
slot_idx = slot_idx[:num_seq]
|
||||
seq_start = cu_seqlen[:num_seq]
|
||||
|
||||
# check for num_heads
|
||||
@ -314,7 +314,7 @@ def torch_backend_mha_with_cache(
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
cache_loc,
|
||||
slot_idx,
|
||||
input_pos,
|
||||
scale,
|
||||
y,
|
||||
@ -329,7 +329,7 @@ def torch_backend_mha_with_cache(
|
||||
k,
|
||||
v,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
slot_idx,
|
||||
k_cache,
|
||||
v_cache,
|
||||
seq_len,
|
||||
@ -354,7 +354,7 @@ def torch_backend_mha_with_cache_fake(
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
# EXTRA METADATA
|
||||
#
|
||||
@ -394,7 +394,7 @@ class TorchBackendAttention(AttentionDescriptor):
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
return ["batch_info_host", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
|
||||
return ["batch_info_host", "seq_len", "input_pos", "slot_idx", "cu_seqlen"]
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
|
||||
@ -29,20 +29,22 @@ from .triton_kernels.attention_with_kv_cache import (
|
||||
)
|
||||
|
||||
|
||||
def _generate_mha(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
def _decode_attention(
|
||||
q: torch.Tensor, # [num_decode, num_heads, qk_head_dim]
|
||||
k: torch.Tensor, # [num_decode, num_kv_heads, qk_head_dim]
|
||||
v: torch.Tensor, # [num_decode, num_kv_heads, v_head_dim]
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
cache_locs: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
slot_idx: torch.Tensor, # [num_decode]
|
||||
input_pos: torch.Tensor, # [num_decode]
|
||||
scale: float,
|
||||
out: torch.Tensor,
|
||||
out: torch.Tensor, # [num_decode, num_heads, v_head_dim]
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
):
|
||||
b, (n_heads, q_d_head) = q.shape[0], q.shape[-2:]
|
||||
) -> None:
|
||||
"""Handle decode phase - single token generation attention."""
|
||||
num_decode = q.shape[0]
|
||||
n_heads, q_d_head = q.shape[-2:]
|
||||
max_seq_len, n_kv_heads = k_cache.shape[1:3]
|
||||
v_d_head = v.shape[-1]
|
||||
device = q.device
|
||||
@ -50,13 +52,13 @@ def _generate_mha(
|
||||
SEQ_BLOCK_SIZE = 64
|
||||
num_blocks = (max_seq_len + SEQ_BLOCK_SIZE - 1) // SEQ_BLOCK_SIZE
|
||||
stage1_output_values = torch.empty(
|
||||
b, n_heads, num_blocks, v_d_head, device=device, dtype=torch.float32
|
||||
num_decode, n_heads, num_blocks, v_d_head, device=device, dtype=torch.float32
|
||||
)
|
||||
stage1_output_logsumexp = torch.empty(
|
||||
b, n_heads, num_blocks, device=device, dtype=torch.float32
|
||||
num_decode, n_heads, num_blocks, device=device, dtype=torch.float32
|
||||
) - float("inf")
|
||||
|
||||
update_kv_cache[(b, n_kv_heads, 1)](
|
||||
update_kv_cache[(num_decode, n_kv_heads, 1)](
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
@ -64,7 +66,7 @@ def _generate_mha(
|
||||
k_cache,
|
||||
v_cache,
|
||||
input_pos,
|
||||
cache_locs,
|
||||
slot_idx,
|
||||
max_seq_len,
|
||||
n_kv_heads,
|
||||
q_d_head,
|
||||
@ -76,7 +78,7 @@ def _generate_mha(
|
||||
HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads))
|
||||
gqa_attention_kv_stage1[
|
||||
(
|
||||
b,
|
||||
num_decode,
|
||||
n_kv_heads,
|
||||
num_blocks,
|
||||
)
|
||||
@ -84,7 +86,7 @@ def _generate_mha(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
cache_locs,
|
||||
slot_idx,
|
||||
input_pos,
|
||||
stage1_output_values,
|
||||
stage1_output_logsumexp,
|
||||
@ -101,7 +103,7 @@ def _generate_mha(
|
||||
)
|
||||
has_sinks = sinks is not None
|
||||
|
||||
attention_kv_stage2[(b, n_heads, 1)](
|
||||
attention_kv_stage2[(num_decode, n_heads, 1)](
|
||||
stage1_output_values,
|
||||
stage1_output_logsumexp,
|
||||
out,
|
||||
@ -115,29 +117,30 @@ def _generate_mha(
|
||||
)
|
||||
|
||||
|
||||
def _flattened_context_mha(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
def _prefill_attention(
|
||||
q: torch.Tensor, # [num_prefill_tokens, num_heads, qk_head_dim]
|
||||
k: torch.Tensor, # [num_prefill_tokens, num_kv_heads, qk_head_dim]
|
||||
v: torch.Tensor, # [num_prefill_tokens, num_kv_heads, v_head_dim]
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
input_pos: torch.Tensor, # [num_prefill]
|
||||
slot_idx: torch.Tensor, # [num_prefill]
|
||||
seq_len: torch.Tensor, # [num_prefill]
|
||||
seq_start: torch.Tensor, # [num_prefill]
|
||||
scale: float,
|
||||
out: torch.Tensor,
|
||||
out: torch.Tensor, # [num_prefill_tokens, num_heads, v_head_dim]
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
# NOTE: s_total == sum(seq_len)
|
||||
s_total, n_heads, q_d_head = q.shape
|
||||
"""Handle prefill phase - context attention with variable sequence lengths."""
|
||||
# NOTE: num_prefill_tokens == sum(seq_len)
|
||||
num_prefill_tokens, n_heads, q_d_head = q.shape
|
||||
max_cache_seq_len, n_kv_heads = k_cache.shape[1:3]
|
||||
v_d_head = v.shape[-1]
|
||||
BATCH_SIZE: int = len(input_pos)
|
||||
num_prefill = len(input_pos)
|
||||
SEQ_BLOCK = 32
|
||||
|
||||
update_kv_cache[(BATCH_SIZE, n_kv_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)](
|
||||
update_kv_cache[(num_prefill, n_kv_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)](
|
||||
k,
|
||||
v,
|
||||
seq_len,
|
||||
@ -145,7 +148,7 @@ def _flattened_context_mha(
|
||||
k_cache,
|
||||
v_cache,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
slot_idx,
|
||||
max_cache_seq_len,
|
||||
n_kv_heads,
|
||||
q_d_head,
|
||||
@ -154,8 +157,7 @@ def _flattened_context_mha(
|
||||
GENERATE_ONLY=False,
|
||||
)
|
||||
|
||||
# TODO: use input_pos to get the correct cache locations
|
||||
grid = (BATCH_SIZE, n_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
|
||||
grid = (num_prefill, n_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
|
||||
has_sinks = sinks is not None
|
||||
|
||||
context_attention_kv_flattened[grid](
|
||||
@ -165,7 +167,7 @@ def _flattened_context_mha(
|
||||
k_cache,
|
||||
v_cache,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
slot_idx,
|
||||
out,
|
||||
scale,
|
||||
n_heads,
|
||||
@ -190,7 +192,7 @@ def flattened_mha_with_cache(
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
# EXTRA METADATA
|
||||
#
|
||||
@ -208,73 +210,64 @@ def flattened_mha_with_cache(
|
||||
|
||||
NOTE: this op can also handle seq_len==0, which might be useful for CUDAGRAPH.
|
||||
"""
|
||||
# check for sequence info and truncate metadata
|
||||
# Extract batch info from batch_info_host
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
num_total_tokens = num_prefill_tokens + num_decode
|
||||
|
||||
seq_len = seq_len[:num_seq]
|
||||
input_pos = input_pos[:num_seq]
|
||||
cache_loc = cache_loc[:num_seq]
|
||||
seq_start = cu_seqlen[:num_seq]
|
||||
|
||||
# b, s info
|
||||
# NOTE: b, s are just the shapes of the input tensor q; not necessarily the number of sequences.
|
||||
# Generally speaking, we expect one of two cases here:
|
||||
# 1. b > 0, s==1: this indicates a generate-only batch of tokens.
|
||||
# 2. b==1, s > 0: this indicates a mixed context+generate phase. The actual number of sequences
|
||||
# and number of tokens per sequence are encoded in seq_len and seq_start.
|
||||
# Get cache and head dimensions
|
||||
num_kv_heads, qk_head_dim = k_cache.shape[-2:]
|
||||
v_head_dim = v_cache.shape[-1]
|
||||
b, s = q.shape[:2]
|
||||
|
||||
# check for num_heads
|
||||
# Determine num_heads from input shape
|
||||
num_heads = q.shape[2] // qk_head_dim if q.ndim == 3 else q.shape[2]
|
||||
|
||||
# Define output shape
|
||||
# Define output shape (preserve original input format)
|
||||
output_shape = (b, s, num_heads * v_head_dim) if q.ndim == 3 else (b, s, num_heads, v_head_dim)
|
||||
|
||||
# reshapes with head_dim
|
||||
if s == 1:
|
||||
bs_view = (b, s)
|
||||
else:
|
||||
bs_view = (b * s,)
|
||||
|
||||
q = q.contiguous().view(*bs_view, num_heads, qk_head_dim)
|
||||
k = k.contiguous().view(*bs_view, num_kv_heads, qk_head_dim)
|
||||
v = v.contiguous().view(*bs_view, num_kv_heads, v_head_dim)
|
||||
# Flatten Q, K, V to [total_tokens, heads, head_dim]
|
||||
bs = b * s
|
||||
q_flat = q.contiguous().view(bs, num_heads, qk_head_dim)
|
||||
k_flat = k.contiguous().view(bs, num_kv_heads, qk_head_dim)
|
||||
v_flat = v.contiguous().view(bs, num_kv_heads, v_head_dim)
|
||||
|
||||
# Compute scale if not provided
|
||||
scale = 1.0 / math.sqrt(qk_head_dim) if scale is None else scale
|
||||
# run attention
|
||||
y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous()
|
||||
if s == 1:
|
||||
# generate-only phase
|
||||
_generate_mha(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
|
||||
# Preallocate output tensor
|
||||
y = q_flat.new_empty(bs, num_heads, v_head_dim)
|
||||
|
||||
# PREFILL: process context tokens with variable sequence lengths
|
||||
if num_prefill > 0:
|
||||
_prefill_attention(
|
||||
q_flat[:num_prefill_tokens],
|
||||
k_flat[:num_prefill_tokens],
|
||||
v_flat[:num_prefill_tokens],
|
||||
k_cache,
|
||||
v_cache,
|
||||
cache_loc,
|
||||
input_pos,
|
||||
input_pos[:num_prefill],
|
||||
slot_idx[:num_prefill],
|
||||
seq_len[:num_prefill],
|
||||
cu_seqlen[:num_prefill],
|
||||
scale,
|
||||
y,
|
||||
y[:num_prefill_tokens],
|
||||
sinks,
|
||||
sliding_window,
|
||||
)
|
||||
else:
|
||||
# mixed context + generate phase
|
||||
_flattened_context_mha(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
|
||||
# DECODE: process single-token generation
|
||||
if num_decode > 0:
|
||||
_decode_attention(
|
||||
q_flat[num_prefill_tokens:num_total_tokens],
|
||||
k_flat[num_prefill_tokens:num_total_tokens],
|
||||
v_flat[num_prefill_tokens:num_total_tokens],
|
||||
k_cache,
|
||||
v_cache,
|
||||
seq_len,
|
||||
seq_start,
|
||||
slot_idx[num_prefill:num_seq],
|
||||
input_pos[num_prefill:num_seq],
|
||||
scale,
|
||||
y,
|
||||
y[num_prefill_tokens:num_total_tokens],
|
||||
sinks,
|
||||
sliding_window,
|
||||
)
|
||||
@ -292,7 +285,7 @@ def flattened_mha_fake(
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
# EXTRA METADATA
|
||||
#
|
||||
@ -331,7 +324,7 @@ class TritonAttention(AttentionDescriptor):
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
return ["batch_info_host", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
|
||||
return ["batch_info_host", "seq_len", "input_pos", "slot_idx", "cu_seqlen"]
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
|
||||
@ -13,7 +13,7 @@ def update_kv_cache(
|
||||
k_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
|
||||
v_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
|
||||
input_pos_ptr, # Specifies the sequence index in the caches at which to write the provided kv
|
||||
cache_loc_ptr, # Specifies the batch index for each of the input sequences
|
||||
slot_idx_ptr, # Specifies the slot index for each of the input sequences
|
||||
MAX_SEQ_LENGTH: tl.constexpr,
|
||||
N_KV_HEADS: tl.constexpr,
|
||||
Q_D_HEAD: tl.constexpr,
|
||||
@ -34,14 +34,14 @@ def update_kv_cache(
|
||||
seq_len = tl.load(seq_len_ptr + batch_id)
|
||||
|
||||
# cache is [bsnd]
|
||||
# cache_loc_ptr stores the batch index for the sequences provided to the kernel.
|
||||
cache_loc = tl.load(cache_loc_ptr + batch_id)
|
||||
# slot_idx_ptr stores the slot index for the sequences provided to the kernel.
|
||||
slot_idx = tl.load(slot_idx_ptr + batch_id)
|
||||
|
||||
kv_position = tl.load(input_pos_ptr + batch_id)
|
||||
|
||||
K_D_HEAD: tl.constexpr = Q_D_HEAD
|
||||
k_cache_batch_offset = cache_loc * N_KV_HEADS * MAX_SEQ_LENGTH * K_D_HEAD
|
||||
v_cache_batch_offset = cache_loc * N_KV_HEADS * MAX_SEQ_LENGTH * V_D_HEAD
|
||||
k_cache_batch_offset = slot_idx * N_KV_HEADS * MAX_SEQ_LENGTH * K_D_HEAD
|
||||
v_cache_batch_offset = slot_idx * N_KV_HEADS * MAX_SEQ_LENGTH * V_D_HEAD
|
||||
|
||||
k_dhead_offsets = tl.arange(0, triton.next_power_of_2(K_D_HEAD))
|
||||
k_dhead_mask = k_dhead_offsets < K_D_HEAD
|
||||
@ -99,7 +99,7 @@ def gqa_attention_kv_stage1(
|
||||
q_ptr, # [Batch, 1, N_HEADS, D_HEAD]
|
||||
k_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
|
||||
v_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
|
||||
cache_loc_ptr, # [Batch] # Specifies the batch index for each of the generate tokens.
|
||||
slot_idx_ptr, # [Batch] # Specifies the slot index for each of the generate tokens.
|
||||
input_pos_ptr, # [Batch]
|
||||
output_values_ptr, # [Batch, N_HEADS, num_blocks, D_HEAD]
|
||||
output_logsumexp_ptr, # [Batch, N_HEADS, num_blocks]
|
||||
@ -137,9 +137,9 @@ def gqa_attention_kv_stage1(
|
||||
seq_block_id = tl.program_id(axis=2)
|
||||
|
||||
kv_position = tl.load(input_pos_ptr + batch_id)
|
||||
kv_batch_id = tl.load(cache_loc_ptr + batch_id)
|
||||
slot_idx = tl.load(slot_idx_ptr + batch_id)
|
||||
K_D_HEAD: tl.constexpr = Q_D_HEAD
|
||||
batch_offset = kv_batch_id * N_KV_HEADS * MAX_SEQ_LEN
|
||||
batch_offset = slot_idx * N_KV_HEADS * MAX_SEQ_LEN
|
||||
|
||||
# Offsets for the block of sequences this program processes.
|
||||
seq_start_pos = seq_block_id * SEQ_BLOCK_SIZE
|
||||
@ -252,7 +252,7 @@ def attention_kv_stage1(
|
||||
q_ptr, # [Batch, 1, N_HEADS, D_HEAD]
|
||||
k_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
|
||||
v_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
|
||||
cache_loc_ptr, # [Batch] # Specifies the batch index for each of the generate tokens.
|
||||
slot_idx_ptr, # [Batch] # Specifies the slot index for each of the generate tokens.
|
||||
input_pos_ptr, # [Batch]
|
||||
output_values_ptr, # [Batch, N_HEADS, num_blocks, D_HEAD]
|
||||
output_logsumexp_ptr, # [Batch, N_HEADS, num_blocks]
|
||||
@ -283,8 +283,8 @@ def attention_kv_stage1(
|
||||
epsilon: tl.constexpr = 1e-38 # float32 smallest positive number
|
||||
|
||||
kv_position = tl.load(input_pos_ptr + batch_id)
|
||||
kv_batch_id = tl.load(cache_loc_ptr + batch_id)
|
||||
kv_batch_offset = kv_batch_id * N_KV_HEADS * MAX_SEQ_LEN * D_HEAD
|
||||
slot_idx = tl.load(slot_idx_ptr + batch_id)
|
||||
slot_batch_offset = slot_idx * N_KV_HEADS * MAX_SEQ_LEN * D_HEAD
|
||||
# Offsets for the block of sequences this program processes.
|
||||
seq_start_pos = seq_block_id * SEQ_BLOCK_SIZE
|
||||
|
||||
@ -308,7 +308,7 @@ def attention_kv_stage1(
|
||||
q = tl.load(q_ptr + q_batch_offset + q_head_offset + dhead_offsets, mask=dhead_mask)
|
||||
|
||||
kv_block_offsets = (
|
||||
kv_batch_offset
|
||||
slot_batch_offset
|
||||
+ seq_offsets[:, None] * D_HEAD * N_KV_HEADS
|
||||
+ kv_head_offset
|
||||
+ dhead_offsets[None, :]
|
||||
@ -582,7 +582,7 @@ def context_attention_kv_flattened(
|
||||
k_cache_ptr, # [bsnd]
|
||||
v_cache_ptr, # [bsnd]
|
||||
input_pos_ptr, # [b] # specifies the location in the sequence where kv must be written back.
|
||||
cache_loc_ptr, # [b] # location of the sequence in the cache.
|
||||
slot_idx_ptr, # [b] # slot index of the sequence in the cache.
|
||||
o_ptr,
|
||||
SCALE: tl.constexpr,
|
||||
N_HEADS: tl.constexpr, # Number of heads
|
||||
@ -611,10 +611,10 @@ def context_attention_kv_flattened(
|
||||
HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS
|
||||
|
||||
# cache is [bsnd]
|
||||
# cache_loc_ptr stores the batch index for the sequences provided to the kernel.
|
||||
cache_loc = tl.load(cache_loc_ptr + batch_id)
|
||||
# slot_idx_ptr stores the slot index for the sequences provided to the kernel.
|
||||
slot_idx = tl.load(slot_idx_ptr + batch_id)
|
||||
|
||||
cache_batch_offset = cache_loc * N_KV_HEADS * MAX_SEQ_LENGTH
|
||||
cache_batch_offset = slot_idx * N_KV_HEADS * MAX_SEQ_LENGTH
|
||||
cache_head_offset = head_id // HEAD_RATIO
|
||||
|
||||
q_dhead_offsets = tl.arange(0, triton.next_power_of_2(Q_D_HEAD))
|
||||
@ -735,7 +735,7 @@ def update_kv_cache_rope_fusion(
|
||||
k_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
|
||||
v_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
|
||||
input_pos_ptr, # Specifies the sequence index in the caches at which to write the provided kv
|
||||
cache_loc_ptr, # Specifies the batch index for each of the input sequences
|
||||
slot_idx_ptr, # Specifies the slot index for each of the input sequences
|
||||
f_ptr, # [MAX_SEQ_LEN, D_HEAD//2, 2] # frequencies for rope embedding.
|
||||
MAX_SEQ_LENGTH: tl.constexpr,
|
||||
N_HEADS: tl.constexpr,
|
||||
@ -766,12 +766,12 @@ def update_kv_cache_rope_fusion(
|
||||
seq_len = tl.load(seq_len_ptr + batch_id)
|
||||
|
||||
# cache is [bsnd]
|
||||
# cache_loc_ptr stores the batch index for the sequences provided to the kernel.
|
||||
cache_loc = tl.load(cache_loc_ptr + batch_id)
|
||||
# slot_idx_ptr stores the slot index for the sequences provided to the kernel.
|
||||
slot_idx = tl.load(slot_idx_ptr + batch_id)
|
||||
|
||||
kv_position = tl.load(input_pos_ptr + batch_id)
|
||||
|
||||
cache_batch_offset = cache_loc * N_KV_HEADS * MAX_SEQ_LENGTH * D_HEAD
|
||||
cache_batch_offset = slot_idx * N_KV_HEADS * MAX_SEQ_LENGTH * D_HEAD
|
||||
cache_head_offset = kv_head_id * D_HEAD
|
||||
|
||||
# Assuming D_HEAD is a power of 2
|
||||
|
||||
@ -1,395 +0,0 @@
|
||||
import triton
from triton import language as tl

"""
Kernels based on a paged KV cache.

Parameter info:
tensors:
    - q: [b*s, n, d], flattened queries.
    - k/v: [b*s, n, d], flattened key/value.
    - seq_len: [b], length of each sequence in the batch.
        `seq_len` can be 1 (generate) or larger (context).
    - seq_start: [b], start index of each sequence in the b*s dim of q/k/v.
    - k_cache/v_cache: [num_pages, PAGE_SIZE, n, d], paged KV cache.
        Incoming k/v is split into groups of PAGE_SIZE tokens and then
        mapped to non-contiguous memory in the KV cache.
    - page_table: [b, max_num_pages_per_seq], mapping logic of each sequence.
    - cache_loc: [b], mapping logic of `batch_id` in q/k/v to index in `page_table`.
    - cache_len: [b], existing cached k/v length of each sequence.

constexpr:
    - N_HEADS/N_KV_HEADS: shape of dim [n] in q or k/v.
    - D_HEAD: shape of dim [d] in q/k/v.
        Assuming power of 2.
    - SEQ_BLOCK: block size to split dim [s].
        Assuming power of 2.
        Splits k/v in the update kernel and q in the context/generate kernels.
    - MAX_SEQ_LENGTH: seq_len <= MAX_SEQ_LENGTH.
    - PAGE_SIZE: size of each kv cache page.
        Assuming power of 2 and SEQ_BLOCK % PAGE_SIZE == 0.
    - PAGE_TABLE_STRIDE: stride of dim [b] in `page_table`.

KV cache access logic in the update kernel:
    1. batch_id i accesses k[seq_start[i] : seq_start[i] + seq_len[i]],
       which can be split into pages [a:b] of the sequence.
    2. Look up cache_len[i] to find if the sequence has cached k/v.
    3. Look up page_table[cache_loc[i], cache_len[i] + a : cache_len[i] + b]
       to get the corresponding pages in the k_cache, with result [c:d].
    4. Then update k_cache[c:d] with the k values.
"""
|
||||
|
||||
|
||||
@triton.jit
|
||||
def update_paged_kv_cache(
|
||||
k_ptr, # [B*S, N, D]
|
||||
v_ptr, # [B*S, N, D]
|
||||
seq_len_ptr, # [b] # length of each sequence in a batch
|
||||
seq_start_indices_ptr, # [b] # start indices of a sequence in flattened q/k/v.
|
||||
k_cache_ptr, # [num_pages, page_size, n, d]
|
||||
v_cache_ptr, # [num_pages, page_size, n, d]
|
||||
cache_loc_ptr, # [b] # index of the sequence in the page table.
|
||||
cache_len_ptr, # [b] # length of the sequence already in kv cache.
|
||||
page_table_ptr, # [b, max_num_pages_per_seq] # loc of the block page in the cache.
|
||||
N_KV_HEADS: tl.constexpr, # Number of KV heads.
|
||||
D_HEAD: tl.constexpr, # Dimension of each head.
|
||||
SEQ_BLOCK: tl.constexpr,
|
||||
MAX_SEQ_LENGTH: tl.constexpr,
|
||||
PAGE_SIZE: tl.constexpr,
|
||||
PAGE_TABLE_STRIDE: tl.constexpr,
|
||||
GENERATE_ONLY: tl.constexpr,
|
||||
):
|
||||
batch_id = tl.program_id(axis=0)
|
||||
head_id = tl.program_id(axis=1)
|
||||
seq_block_id = tl.program_id(axis=2)
|
||||
|
||||
# Each program is responsible for a block of tokens in a single batch.
|
||||
if GENERATE_ONLY:
|
||||
seq_start_index = batch_id
|
||||
seq_len: tl.constexpr = 1
|
||||
else:
|
||||
seq_start_index = tl.load(seq_start_indices_ptr + batch_id)
|
||||
seq_len = tl.load(seq_len_ptr + batch_id)
|
||||
|
||||
cache_len = tl.load(cache_len_ptr + batch_id)
|
||||
|
||||
# cache is [num_pages, page_size, n, d]
|
||||
# cache_loc_ptr stores the batch index for the sequences provided to the kernel.
|
||||
cache_loc = tl.load(cache_loc_ptr + batch_id)
|
||||
cache_head_offset = head_id * D_HEAD
|
||||
|
||||
# Assuming D_HEAD is a power of 2
|
||||
dhead_offsets = tl.arange(0, D_HEAD)
|
||||
dhead_mask = dhead_offsets < D_HEAD
|
||||
|
||||
seq_offsets = seq_block_id * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
|
||||
seq_mask = seq_offsets < seq_len
|
||||
|
||||
load_mask = seq_mask[:, None] * dhead_mask[None, :]
|
||||
|
||||
kv_batch_offset = seq_start_index * N_KV_HEADS * D_HEAD
|
||||
kv_head_offset = cache_head_offset
|
||||
|
||||
# Write back to kv-caches
|
||||
ks = tl.load(
|
||||
k_ptr
|
||||
+ kv_batch_offset
|
||||
+ seq_offsets[:, None] * N_KV_HEADS * D_HEAD
|
||||
+ kv_head_offset
|
||||
+ dhead_offsets[None, :],
|
||||
mask=load_mask,
|
||||
)
|
||||
vs = tl.load(
|
||||
v_ptr
|
||||
+ kv_batch_offset
|
||||
+ seq_offsets[:, None] * N_KV_HEADS * D_HEAD
|
||||
+ kv_head_offset
|
||||
+ dhead_offsets[None, :],
|
||||
mask=load_mask,
|
||||
)
|
||||
|
||||
# assuming SEQ_BLOCK can be divided by PAGE_SIZE and PAGE_SIZE is a power of 2.
|
||||
SEQ_BLOCK_PAGE: tl.constexpr = SEQ_BLOCK // PAGE_SIZE
|
||||
MAX_NUM_PAGES: tl.constexpr = (MAX_SEQ_LENGTH + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
# cache_len // PAGE_SIZE means history pages
|
||||
# if decode sequence, then seq_len = 1 and only seq_block_id = 0 works,
|
||||
kv_pages = seq_block_id * SEQ_BLOCK_PAGE + tl.arange(0, SEQ_BLOCK_PAGE) + cache_len // PAGE_SIZE
|
||||
cache_pages = tl.load(
|
||||
page_table_ptr + cache_loc * PAGE_TABLE_STRIDE + kv_pages, mask=kv_pages < MAX_NUM_PAGES
|
||||
)
|
||||
|
||||
page_offsets = tl.arange(0, PAGE_SIZE)
|
||||
# shape [SEQ_BLOCK], means [cache_pages, page_offsets]
|
||||
cache_seq_offset = tl.reshape(
|
||||
cache_pages[:, None] * PAGE_SIZE + page_offsets[None, :], [SEQ_BLOCK]
|
||||
)
|
||||
# write offset inside the page
|
||||
cache_seq_offset += cache_len % PAGE_SIZE
|
||||
|
||||
cache_offsets = (
|
||||
cache_seq_offset[:, None] * N_KV_HEADS * D_HEAD + kv_head_offset + dhead_offsets[None, :]
|
||||
)
|
||||
tl.store(k_cache_ptr + cache_offsets, ks, load_mask)
|
||||
tl.store(v_cache_ptr + cache_offsets, vs, load_mask)
|
||||
|
||||
|
||||
# TODO: Write a doc describing the 2 stage algorithm
|
||||
@triton.jit
|
||||
def attention_kv_paged_stage1(
|
||||
q_ptr, # [Batch, 1, N_HEADS, D_HEAD]
|
||||
k_cache_ptr, # [NUM_PAGES, PAGE_SIZE, N_HEADS, D_HEAD]
|
||||
v_cache_ptr, # [NUM_PAGES, PAGE_SIZE, N_HEADS, D_HEAD]
|
||||
cache_loc_ptr, # [Batch] # Specifies the batch index for each of the generate tokens.
|
||||
page_table_ptr, # [Batch, num_pages_per_seq]
|
||||
cache_len_ptr, # [Batch] # Number of tokens in kv cache.
|
||||
output_values_ptr, # [Batch, N_HEADS, num_blocks, D_HEAD]
|
||||
output_logsumexp_ptr, # [Batch, N_HEADS, num_blocks]
|
||||
num_blocks,
|
||||
MAX_SEQ_LEN: tl.constexpr, # Maximum supported sequence length
|
||||
N_HEADS: tl.constexpr, # Number of heads
|
||||
N_KV_HEADS: tl.constexpr, # Number of KV heads.
|
||||
D_HEAD: tl.constexpr, # Dimension of each head.
|
||||
# Block size used for tiling the sequence dim.
|
||||
SEQ_BLOCK_SIZE: tl.constexpr,
|
||||
PAGE_SIZE: tl.constexpr,
|
||||
PAGE_TABLE_STRIDE: tl.constexpr,
|
||||
):
|
||||
"""Attention kernel to be used during the generate phase.
|
||||
|
||||
Uses flash decoding.
|
||||
KV-cache layout is assumed to be [Batch, Head, Seq, Dim]
|
||||
1. Fetch the K-cache from 0 to input_pos
|
||||
2. Fetch the V-cache from 0 to input_pos
|
||||
3. A = Q*K^T [1,D_HEAD] * [1,seq_len,D_HEAD] -> [1, seq_len]
|
||||
4. S = softmax(A)
|
||||
5. O = S*V [1, seq_len] * [1, seq_len, D_HEAD] -> [1, D_HEAD]
|
||||
"""
|
||||
# Assume KV-cache layout: [Batch, Head, Seq, Dim]
|
||||
# A program is responsible for 1 batch, 1 head and a block of sequences.
|
||||
batch_id = tl.program_id(axis=0)
|
||||
head_id = tl.program_id(axis=1)
|
||||
seq_block_id = tl.program_id(axis=2)
|
||||
|
||||
SEQ_BLOCK_PAGE: tl.constexpr = SEQ_BLOCK_SIZE // PAGE_SIZE
|
||||
MAX_NUM_PAGES: tl.constexpr = MAX_SEQ_LEN // PAGE_SIZE
|
||||
|
||||
cache_loc = tl.load(cache_loc_ptr + batch_id)
|
||||
seq_len = tl.load(cache_len_ptr + batch_id)
|
||||
# Offsets for the block of sequences this program processes.
|
||||
seq_start_pos = seq_block_id * SEQ_BLOCK_SIZE
|
||||
|
||||
if seq_start_pos > seq_len:
|
||||
return
|
||||
seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
|
||||
seq_mask = seq_offsets <= seq_len
|
||||
# Assuming D_HEAD is a power of 2
|
||||
dhead_offsets = tl.arange(0, D_HEAD)
|
||||
dhead_mask = dhead_offsets < D_HEAD
|
||||
|
||||
HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS
|
||||
cache_head_offset = (head_id // HEAD_RATIO) * D_HEAD
|
||||
|
||||
sm_scale: tl.constexpr = 1 / (D_HEAD**0.5)
|
||||
|
||||
# Program loads the entire Q for the head assigned to it.
|
||||
# [D_HEAD]
|
||||
q_batch_offset = batch_id * N_HEADS * D_HEAD
|
||||
q_head_offset = head_id * D_HEAD
|
||||
q = tl.load(q_ptr + q_batch_offset + q_head_offset + dhead_offsets)
|
||||
|
||||
kv_mask = seq_mask[:, None] * dhead_mask[None, :]
|
||||
|
||||
kv_pages = seq_block_id * SEQ_BLOCK_PAGE + tl.arange(0, SEQ_BLOCK_PAGE)
|
||||
cache_pages = tl.load(
|
||||
page_table_ptr + cache_loc * PAGE_TABLE_STRIDE + kv_pages, mask=kv_pages < MAX_NUM_PAGES
|
||||
)
|
||||
|
||||
page_offsets = tl.arange(0, PAGE_SIZE)
|
||||
# shape [SEQ_BLOCK], means [cache_pages, page_offsets]
|
||||
# token offsets in the paged kv cache
|
||||
cache_seq_offset = tl.reshape(
|
||||
cache_pages[:, None] * PAGE_SIZE + page_offsets[None, :], [SEQ_BLOCK_SIZE]
|
||||
)
|
||||
|
||||
cache_offsets = (
|
||||
cache_seq_offset[:, None] * N_KV_HEADS * D_HEAD + cache_head_offset + dhead_offsets[None, :]
|
||||
)
|
||||
|
||||
k = tl.load(k_cache_ptr + cache_offsets, mask=kv_mask)
|
||||
v = tl.load(v_cache_ptr + cache_offsets, mask=kv_mask)
|
||||
|
||||
# Note: check the output precision of the sum.
|
||||
# compute q*K^T
|
||||
# [D_HEAD] * [seq_block, D_HEAD], sum along axis 1
|
||||
attn = tl.sum(q[None, :] * k, axis=1) # [seq_block]
|
||||
attn = attn.to(tl.float32)
|
||||
attn *= sm_scale
|
||||
max_attn = tl.max(attn)
|
||||
# Set to -inf attn values where mask is not set. This forces exp(attn) to 0.
|
||||
attn = tl.where(seq_mask, attn, float("-inf"))
|
||||
exp_attn = tl.exp(attn - max_attn)
|
||||
|
||||
sumexp = tl.sum(exp_attn, axis=0) # scalar.
|
||||
|
||||
# [seq_len] * [seq_len, D_HEAD], sum along axis 0
|
||||
output = tl.sum(exp_attn[:, None] * v, axis=0) # [D_HEAD]
|
||||
|
||||
output = output / sumexp
|
||||
|
||||
# We store the log-sum-exp after removing the max.
|
||||
logsumexp = tl.log(sumexp) + max_attn
|
||||
# when seq_mask is all false, max_attn will be -inf and sumexp is zero
|
||||
|
||||
tl.store(
|
||||
output_values_ptr
|
||||
+ batch_id * N_HEADS * D_HEAD * num_blocks
|
||||
+ head_id * D_HEAD * num_blocks
|
||||
+ seq_block_id * D_HEAD
|
||||
+ dhead_offsets,
|
||||
output,
|
||||
)
|
||||
tl.store(
|
||||
output_logsumexp_ptr
|
||||
+ batch_id * N_HEADS * num_blocks
|
||||
+ head_id * num_blocks
|
||||
+ seq_block_id,
|
||||
logsumexp,
|
||||
)
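# --- Hedged illustration (not part of the original kernels) ---
# Stage 2 of flash decoding (implemented elsewhere) combines the per-block partial
# outputs and log-sum-exp values stored above. The helper below is an illustrative
# PyTorch sketch for a single (batch, head) pair; `combine_stage1_blocks` is a name
# invented for this sketch.
import torch


def combine_stage1_blocks(values: torch.Tensor, logsumexp: torch.Tensor) -> torch.Tensor:
    # values: [num_blocks, D_HEAD] partial outputs; logsumexp: [num_blocks].
    lse_max = logsumexp.max()
    weights = torch.exp(logsumexp - lse_max)  # relative softmax mass of each sequence block
    combined = (weights[:, None] * values).sum(dim=0)
    return combined / weights.sum()  # renormalize by the total softmax mass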
|
||||
|
||||
|
||||
@triton.jit
|
||||
def context_attention_kv_paged(
|
||||
q_ptr, # [b*s,nd]
|
||||
seq_len_ptr, # [b] # length of each sequence in a batch
|
||||
seq_start_ptr, # [b] # start indices of a sequence in flattened q/k/v.
|
||||
k_cache_ptr, # [num_pages, page_size, n, d]
|
||||
v_cache_ptr, # [num_pages, page_size, n, d]
|
||||
cache_loc_ptr, # [b] # index of the sequence in the page table.
|
||||
cache_len_ptr, # [Batch] # Number of tokens in kv cache.
|
||||
page_table_ptr, # [b, max_num_pages_per_seq] # loc of the block page in the cache.
|
||||
softmax_scale,
|
||||
o_ptr,
|
||||
N_HEADS: tl.constexpr, # Number of heads
|
||||
N_KV_HEADS: tl.constexpr, # Number of KV heads.
|
||||
D_HEAD: tl.constexpr, # Dimension of each head.
|
||||
SEQ_BLOCK: tl.constexpr,
|
||||
MAX_SEQ_LENGTH: tl.constexpr,
|
||||
PAGE_SIZE: tl.constexpr,
|
||||
PAGE_TABLE_STRIDE: tl.constexpr,
|
||||
):
|
||||
"""Kernel for context phase.
|
||||
|
||||
Fuses rope
|
||||
Assuming:
|
||||
1. Self-attention [seqlen(Q) == seqlen(K)]
|
||||
2. Causal attention
|
||||
3. QKV layout: [b*s,n,d]
|
||||
"""
|
||||
batch_id = tl.program_id(axis=0)
|
||||
head_id = tl.program_id(axis=1)
|
||||
seq_block_id = tl.program_id(axis=2)
|
||||
|
||||
# Each program is responsible for a block of tokens in a single batch.
|
||||
seq_start_index = tl.load(seq_start_ptr + batch_id)
|
||||
seq_len = tl.load(seq_len_ptr + batch_id)
|
||||
|
||||
HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS
|
||||
|
||||
# assuming SEQ_BLOCK can be divided by PAGE_SIZE and PAGE_SIZE is a power of 2.
|
||||
SEQ_BLOCK_PAGE: tl.constexpr = SEQ_BLOCK // PAGE_SIZE
|
||||
MAX_NUM_PAGES: tl.constexpr = (MAX_SEQ_LENGTH + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
|
||||
# cache is [num_pages, page_size, n, d]
|
||||
# cache_loc_ptr stores the batch index for the sequences provided to the kernel.
|
||||
cache_loc = tl.load(cache_loc_ptr + batch_id)
|
||||
table_batch_offset = cache_loc * PAGE_TABLE_STRIDE
|
||||
|
||||
# Assuming D_HEAD is a power of 2
|
||||
dhead_offsets = tl.arange(0, D_HEAD)
|
||||
dhead_mask = dhead_offsets < D_HEAD
|
||||
|
||||
seq_offsets = tl.arange(0, SEQ_BLOCK)
|
||||
q_seq_offsets = seq_block_id * SEQ_BLOCK + seq_offsets
|
||||
seq_mask = q_seq_offsets < seq_len
|
||||
|
||||
load_mask = seq_mask[:, None] * dhead_mask[None, :]
|
||||
|
||||
q_batch_offset = seq_start_index * N_HEADS * D_HEAD
|
||||
q_head_offset = head_id * D_HEAD
|
||||
cache_head_offset = (head_id // HEAD_RATIO) * D_HEAD
|
||||
|
||||
# Q will stay in SRAM
|
||||
q = tl.load(
|
||||
q_ptr
|
||||
+ q_batch_offset
|
||||
+ q_seq_offsets[:, None] * N_HEADS * D_HEAD
|
||||
+ q_head_offset
|
||||
+ dhead_offsets[None, :],
|
||||
mask=load_mask,
|
||||
)
|
||||
acc = tl.zeros([SEQ_BLOCK, D_HEAD], dtype=tl.float32)
|
||||
lse_i = tl.zeros([SEQ_BLOCK], dtype=tl.float32) - float("inf")
|
||||
m_i = tl.zeros([SEQ_BLOCK], dtype=tl.float32) - float("inf")
|
||||
|
||||
cache_len = tl.load(cache_len_ptr + batch_id)
|
||||
total_len = cache_len + seq_len
|
||||
num_blocks = (total_len + SEQ_BLOCK - 1) // SEQ_BLOCK
|
||||
for s in range(0, num_blocks + 1, 1):
|
||||
kv_pages = s * SEQ_BLOCK_PAGE + tl.arange(0, SEQ_BLOCK_PAGE)
|
||||
cache_pages = tl.load(
|
||||
page_table_ptr + table_batch_offset + kv_pages, mask=kv_pages < MAX_NUM_PAGES
|
||||
)
|
||||
|
||||
page_offsets = tl.arange(0, PAGE_SIZE)
|
||||
# shape [SEQ_BLOCK], means [cache_pages, page_offsets]
|
||||
# physical token offsets in the paged kv cache
|
||||
cache_seq_offset = tl.reshape(
|
||||
cache_pages[:, None] * PAGE_SIZE + page_offsets[None, :], [SEQ_BLOCK]
|
||||
)
|
||||
cache_offsets = (
|
||||
cache_seq_offset[:, None] * N_KV_HEADS * D_HEAD
|
||||
+ cache_head_offset
|
||||
+ dhead_offsets[None, :]
|
||||
)
|
||||
|
||||
# logical kv tokens offsets
|
||||
kv_seq_offsets = s * SEQ_BLOCK + seq_offsets
|
||||
kv_seq_mask = kv_seq_offsets < total_len
|
||||
kv_load_mask = kv_seq_mask[:, None] * dhead_mask[None, :]
|
||||
|
||||
k = tl.load(k_cache_ptr + cache_offsets, mask=kv_load_mask)
|
||||
qk = tl.zeros([SEQ_BLOCK, SEQ_BLOCK], dtype=tl.float32)
|
||||
qk += tl.dot(q, k.trans())
|
||||
# causal mask, need to use kv_seq_offsets
|
||||
qk = tl.where(
|
||||
(q_seq_offsets[:, None] + cache_len) >= kv_seq_offsets[None, :], qk, float("-inf")
|
||||
)
|
||||
|
||||
qk *= softmax_scale
|
||||
# rowmax
|
||||
m_ij = tl.maximum(tl.max(qk, 1), lse_i)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
v = tl.load(v_cache_ptr + cache_offsets, mask=kv_load_mask)
|
||||
|
||||
l_ij = tl.sum(p, 1)
|
||||
acc_scale = tl.exp(m_i - m_ij)
|
||||
acc = acc * acc_scale[:, None]
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
m_i = m_ij
|
||||
l_i_new = tl.exp(lse_i - m_ij) + l_ij
|
||||
lse_i = m_ij + tl.log(l_i_new)
|
||||
|
||||
o_scale = tl.exp(m_i - lse_i)
|
||||
|
||||
acc = acc * o_scale[:, None]
|
||||
|
||||
tl.store(
|
||||
o_ptr
|
||||
+ q_batch_offset
|
||||
+ q_seq_offsets[:, None] * N_HEADS * D_HEAD
|
||||
+ q_head_offset
|
||||
+ dhead_offsets[None, :],
|
||||
acc,
|
||||
mask=load_mask,
|
||||
)
|
||||
@ -1,7 +1,6 @@
|
||||
import copy
|
||||
import functools
|
||||
import math
|
||||
from typing import Callable, Dict, Optional, Tuple, Union, final
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union, final
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -11,13 +10,16 @@ import tensorrt_llm.bindings
|
||||
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
from ...._utils import torch_dtype_to_binding
|
||||
from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager
|
||||
from ...pyexecutor.resource_manager import KVCacheManager
|
||||
from ..custom_ops.attention_interface import (
|
||||
PagedResourceHandler,
|
||||
CausalConvResourceHandler,
|
||||
KVPagedResourceHandler,
|
||||
ResourceHandler,
|
||||
ResourceHandlerDict,
|
||||
SequenceInfo,
|
||||
SSMResourceHandler,
|
||||
StateResourceHandler,
|
||||
)
|
||||
from ..distributed.common import all_gather_object, get_world_size
|
||||
@ -93,9 +95,8 @@ class CachedSequenceInterface:
|
||||
self._caches: Dict[str, torch.Tensor] = {}
|
||||
# KVCacheManager (or MambaHybridCacheManager) for managed resources
|
||||
self._kv_cache_manager: Optional[Union[KVCacheManager, MambaHybridCacheManager]] = None
|
||||
# Ordered dicts tracking resource handlers by type
|
||||
self._paged_cache_order: ResourceHandlerDict = {} # Paged resources (kv caches)
|
||||
self._state_resource_order: ResourceHandlerDict = {} # State resources (ssm states)
|
||||
# lookup of unmanaged resources
|
||||
self._unmanaged_resources: List[str] = []
|
||||
|
||||
@property
|
||||
def args(self) -> Tuple[torch.Tensor, ...]:
|
||||
@ -111,7 +112,7 @@ class CachedSequenceInterface:
|
||||
self.info.to(*args, **kwargs)
|
||||
# Only move locally-allocated caches (paged/state caches are managed by cache managers)
|
||||
for name, cache in self._caches.items():
|
||||
if name not in self._paged_cache_order and name not in self._state_resource_order:
|
||||
if name in self._unmanaged_resources:
|
||||
cache.to(*args, **kwargs)
|
||||
|
||||
def update_kv_cache_config(self, **kwargs) -> None:
|
||||
@ -126,54 +127,185 @@ class CachedSequenceInterface:
|
||||
"""Add a resource handler to the cache interface."""
|
||||
self._resource_lookup[name] = resource_handler
|
||||
|
||||
def _create_kv_cache_manager(self, max_tokens: Optional[int] = None) -> int:
|
||||
"""Create KVCacheManager or MambaHybridCacheManager with multi-layer byte-level params.
|
||||
@staticmethod
|
||||
def _check_n_groups_constraint(
|
||||
ssm: SSMResourceHandler, conv: CausalConvResourceHandler
|
||||
) -> bool:
|
||||
"""Check if SSM and Conv handlers satisfy the n_groups constraint.
|
||||
|
||||
This uses a multi-layer approach with byte-level abstraction:
|
||||
- Paged resources: Each resource gets its own layer in KVCacheManager with
|
||||
num_kv_heads=bytes_per_token for that resource, head_dim=1.
|
||||
- State resources: Each resource gets its own layer in MambaCacheManager with
|
||||
head_dim=bytes_per_slot for that resource.
|
||||
|
||||
Each layer's cache is contiguous, avoiding byte-offset slicing within layers.
|
||||
|
||||
When state resources exist, MambaHybridCacheManager is used to manage both.
|
||||
|
||||
Important NOTE on contiguity of managed resources:
|
||||
- We only guarantee contiguity for an individual page or an individual state slot.
|
||||
- Outside of these individual pages/slots, resources are NOT guaranteed to be contiguous.
|
||||
The MambaCacheManager requires: conv_dim = head_dim * num_heads + 2 * n_groups * d_state
|
||||
This method checks if this constraint can be satisfied with integer n_groups >= 0.
|
||||
|
||||
Args:
|
||||
max_tokens: Maximum number of tokens to allocate. If provided, it will use the min value
|
||||
between this value and max_tokens in kv_cache_config.
|
||||
ssm: SSM resource handler with num_heads, head_dim, d_state.
|
||||
conv: Conv resource handler with conv_dim.
|
||||
|
||||
Returns:
|
||||
The final number of tokens that can be cached in the KVCacheManager.
|
||||
NOTE: this number may differ from the provided ``max_tokens`` arg for two reasons:
|
||||
1. the final number of tokens is synced (min) across ranks
|
||||
2. rounding for getting a multiple of tokens_per_block
|
||||
True if the constraint is satisfied, False otherwise.
|
||||
"""
|
||||
# Build per-layer num_kv_heads list for paged resources
|
||||
# Each paged resource becomes one "layer" with num_kv_heads = bytes_per_token
|
||||
num_kv_heads_per_layer = [
|
||||
math.prod(h.token_shape) * h.dtype.itemsize for h in self._paged_cache_order.values()
|
||||
]
|
||||
if ssm.d_state == 0:
|
||||
# d_state=0 means SSM buffer is empty, any conv_dim works
|
||||
return True
|
||||
diff = conv.conv_dim - ssm.head_dim * ssm.num_heads
|
||||
return diff >= 0 and diff % (2 * ssm.d_state) == 0
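# Worked example of the constraint above (illustrative numbers only): with
# num_heads=8, head_dim=64, d_state=16, a conv handler with conv_dim=576 gives
# diff = 576 - 512 = 64 and 64 % (2 * 16) == 0, i.e. n_groups = 2; conv_dim=580
# would fail the divisibility check and the conv resource would be allocated locally.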
|
||||
|
||||
# Calculate total bytes per slot for state resources (modeled as single layer)
|
||||
cumulative_bytes_per_state = [0]
|
||||
for name, handler in self._state_resource_order.items():
|
||||
byte_size = math.prod(handler.state_shape) * handler.dtype.itemsize
|
||||
cumulative_bytes_per_state.append(cumulative_bytes_per_state[-1] + byte_size)
|
||||
@staticmethod
|
||||
def _get_mamba_state_params(
|
||||
ssm_ref: Optional[SSMResourceHandler],
|
||||
ssm_count: int,
|
||||
conv_ref: Optional[CausalConvResourceHandler],
|
||||
conv_count: int,
|
||||
) -> Dict[str, Union[int, torch.dtype, None]]:
|
||||
"""Derive MambaHybridCacheManager parameters from reference state handlers.
|
||||
|
||||
Precondition: If both ssm_ref and conv_ref are provided,
|
||||
the n_groups constraint has already been verified to hold.
|
||||
|
||||
Args:
|
||||
ssm_ref: Reference SSM handler (defines shape/dtype), or None.
|
||||
ssm_count: Number of compatible SSM resources.
|
||||
conv_ref: Reference Conv handler (defines shape/dtype), or None.
|
||||
conv_count: Number of compatible Conv resources.
|
||||
|
||||
Returns:
|
||||
Dictionary of MambaHybridCacheManager constructor parameters.
|
||||
"""
|
||||
# Get SSM parameters (or dummy if not managing SSM)
|
||||
if ssm_ref:
|
||||
num_heads = ssm_ref.num_heads
|
||||
head_dim = ssm_ref.head_dim
|
||||
d_state = ssm_ref.d_state
|
||||
ssm_dtype = ssm_ref.dtype
|
||||
else:
|
||||
# Dummy SSM params - d_state=0 means empty tensor (no memory used)
|
||||
num_heads, head_dim, d_state = 1, 1, 0
|
||||
ssm_dtype = torch.float32
|
||||
|
||||
# Get Conv parameters (or dummy if not managing Conv)
|
||||
if conv_ref:
|
||||
conv_dim = conv_ref.conv_dim
|
||||
d_conv = conv_ref.d_conv
|
||||
conv_dtype = conv_ref.dtype
|
||||
else:
|
||||
# Dummy Conv params - d_conv=1 means state shape (..., 0) (no memory used)
|
||||
conv_dim, d_conv = 1, 1
|
||||
conv_dtype = torch.float32
|
||||
|
||||
# Determine layer count:
|
||||
# - If both buffers used: min() to avoid wasting memory
|
||||
# - If only one buffer used: use that buffer's count
|
||||
if ssm_count > 0 and conv_count > 0:
|
||||
num_layers = min(ssm_count, conv_count)
|
||||
else:
|
||||
num_layers = max(ssm_count, conv_count)
|
||||
assert num_layers > 0, "At least one layer is expected."
|
||||
|
||||
# Derive n_groups from conv_dim constraint (already verified if both are managed)
|
||||
# conv_dim = head_dim * num_heads + 2 * n_groups * d_state
|
||||
if d_state > 0 and conv_ref:
|
||||
n_groups = (conv_dim - head_dim * num_heads) // (2 * d_state)
|
||||
else:
|
||||
n_groups = 0
|
||||
|
||||
return {
|
||||
"mamba_d_state": d_state,
|
||||
"mamba_d_conv": d_conv,
|
||||
"mamba_num_heads": num_heads,
|
||||
"mamba_n_groups": n_groups,
|
||||
"mamba_head_dim": head_dim,
|
||||
"mamba_num_layers": num_layers,
|
||||
"mamba_layer_mask": None,
|
||||
"mamba_cache_dtype": conv_dtype,
|
||||
"mamba_ssm_cache_dtype": ssm_dtype,
|
||||
}
|
||||
|
||||
def _identify_managed_kv_resources(
|
||||
self,
|
||||
) -> Tuple[Optional[KVPagedResourceHandler], ResourceHandlerDict]:
|
||||
"""Identify KV resources compatible with the reference handler for KVCacheManager.
|
||||
|
||||
The first KVPagedResourceHandler becomes the reference. All handlers matching
|
||||
the reference (via __eq__) are collected for managed allocation.
|
||||
|
||||
Returns:
|
||||
Tuple of (reference_handler, managed_resources_dict).
|
||||
reference_handler is None if no KV paged resources exist.
|
||||
"""
|
||||
kv_ref: Optional[KVPagedResourceHandler] = None
|
||||
kv_managed: ResourceHandlerDict = {}
|
||||
|
||||
for name, handler in self._resource_lookup.items():
|
||||
if not isinstance(handler, KVPagedResourceHandler):
|
||||
continue
|
||||
if kv_ref is None:
|
||||
kv_ref = handler
|
||||
if handler == kv_ref:
|
||||
kv_managed[name] = handler
|
||||
|
||||
return kv_ref, kv_managed
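# --- Hedged stand-alone illustration of the grouping policy above ---
# The first handler of the managed type becomes the reference, and only handlers that
# compare equal to it are managed; the rest fall back to local allocation. The
# `FakeHandler` dataclass below is hypothetical and only stands in for __eq__.
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeHandler:
    head_dim: int
    dtype: str


handlers = {
    "attn_0": FakeHandler(128, "bf16"),
    "attn_1": FakeHandler(128, "bf16"),
    "attn_2": FakeHandler(64, "fp8"),  # incompatible -> unmanaged
}
ref = next(iter(handlers.values()))
managed = {k: h for k, h in handlers.items() if h == ref}
assert set(managed) == {"attn_0", "attn_1"}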
|
||||
|
||||
def _identify_managed_state_resources(
|
||||
self,
|
||||
) -> Tuple[Optional[SSMResourceHandler], list, Optional[CausalConvResourceHandler], list]:
|
||||
"""Identify SSM and Conv resources compatible with MambaHybridCacheManager.
|
||||
|
||||
Finds reference handlers for SSM and Conv resources, checks the n_groups constraint,
|
||||
and collects all compatible resources for each type.
|
||||
|
||||
Returns:
|
||||
Tuple of (ssm_ref, ssm_managed, conv_ref, conv_managed) where:
|
||||
- ssm_ref: Reference SSM handler or None
|
||||
- ssm_managed: List of (name, handler) tuples for compatible SSM resources
|
||||
- conv_ref: Reference Conv handler or None (may be None if constraint fails)
|
||||
- conv_managed: List of (name, handler) tuples for compatible Conv resources
|
||||
"""
|
||||
ssm_ref: Optional[SSMResourceHandler] = None
|
||||
conv_ref: Optional[CausalConvResourceHandler] = None
|
||||
|
||||
# Find reference handlers for each state resource type
|
||||
for handler in self._resource_lookup.values():
|
||||
if isinstance(handler, SSMResourceHandler) and ssm_ref is None:
|
||||
ssm_ref = handler
|
||||
elif isinstance(handler, CausalConvResourceHandler) and conv_ref is None:
|
||||
conv_ref = handler
|
||||
if ssm_ref and conv_ref:
|
||||
break
|
||||
|
||||
# Check n_groups constraint: conv_dim = head_dim * num_heads + 2 * n_groups * d_state
|
||||
# If constraint doesn't hold, only manage SSM (more common); Conv goes to local allocation
|
||||
if ssm_ref and conv_ref and not self._check_n_groups_constraint(ssm_ref, conv_ref):
|
||||
ad_logger.debug(
|
||||
"n_groups constraint not satisfied between SSM and Conv handlers. "
|
||||
"Conv resources will be allocated locally."
|
||||
)
|
||||
conv_ref = None # Don't manage Conv via cache manager
|
||||
|
||||
# Collect compatible resources for each managed type (using __eq__ for comparison)
|
||||
ssm_managed = [(n, h) for n, h in self._resource_lookup.items() if ssm_ref == h]
|
||||
conv_managed = [(n, h) for n, h in self._resource_lookup.items() if conv_ref == h]
|
||||
|
||||
return ssm_ref, ssm_managed, conv_ref, conv_managed
|
||||
|
||||
def _prepare_kv_cache_config(
|
||||
self,
|
||||
max_tokens: Optional[int],
|
||||
kv_managed: ResourceHandlerDict,
|
||||
) -> KvCacheConfig:
|
||||
"""Prepare and configure KvCacheConfig for cache manager creation.
|
||||
|
||||
Handles deep copy, max_tokens synchronization across ranks, block reuse settings,
|
||||
copy_on_partial_reuse validation, and free_gpu_memory_fraction normalization.
|
||||
|
||||
Args:
|
||||
max_tokens: Maximum tokens to allocate, or None to use config defaults.
|
||||
kv_managed: Dict of KV resources that will be managed by KVCacheManager.
|
||||
|
||||
Returns:
|
||||
Configured KvCacheConfig ready for cache manager creation.
|
||||
"""
|
||||
# Make a deep copy of the kv_cache_config to avoid modifying the original object
|
||||
kv_cache_config = copy.deepcopy(self._kv_cache_config_original)
|
||||
|
||||
# Disable copy_on_partial_reuse
|
||||
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/10966
|
||||
if kv_cache_config.copy_on_partial_reuse:
|
||||
kv_cache_config.copy_on_partial_reuse = False
|
||||
ad_logger.info("Disabling copy_on_partial_reuse for AutoDeploy backend.")
|
||||
|
||||
# Update kv_cache_config based on max_tokens if provided
|
||||
if max_tokens is not None:
|
||||
# sync max_tokens across ranks
|
||||
@ -185,10 +317,24 @@ class CachedSequenceInterface:
|
||||
kv_cache_config.max_tokens = min(kv_cache_config.max_tokens or max_tokens, max_tokens)
|
||||
|
||||
# Check if we should disable block reuse
|
||||
if kv_cache_config.enable_block_reuse and not self.is_paged():
|
||||
is_paged = all(handler.is_paged for handler in self._resource_lookup.values())
|
||||
if kv_cache_config.enable_block_reuse and not is_paged:
|
||||
kv_cache_config.enable_block_reuse = False
|
||||
ad_logger.info(f"Setting {kv_cache_config.enable_block_reuse=} for non-paged models.")
|
||||
|
||||
# Check if we can use copy on partial reuse
|
||||
num_non_kv_managed_caches = len(self._caches) - len(kv_managed)
|
||||
if (
|
||||
kv_cache_config.enable_block_reuse
|
||||
and kv_cache_config.copy_on_partial_reuse
|
||||
and num_non_kv_managed_caches > 0
|
||||
):
|
||||
kv_cache_config.copy_on_partial_reuse = False
|
||||
ad_logger.info(
|
||||
"Disabling copy_on_partial_reuse: requires all resources to be paged and managed by"
|
||||
f" KVCacheManager ({num_non_kv_managed_caches=})."
|
||||
)
|
||||
|
||||
# Make sure to set free_gpu_memory_fraction to None if set to 0.0
|
||||
# NOTE: KVCacheConfig validator enforces that free_gpu_memory_fraction must be between 0.0
|
||||
# and 1.0 but we allow 0.0 to be set to disable resizing (corresponding to None in the
|
||||
@ -196,113 +342,240 @@ class CachedSequenceInterface:
|
||||
if kv_cache_config.free_gpu_memory_fraction == 0.0:
|
||||
kv_cache_config.free_gpu_memory_fraction = None
|
||||
|
||||
# Common KV cache parameters
|
||||
kv_cache_kwargs = {
|
||||
"kv_cache_config": kv_cache_config,
|
||||
"kv_cache_type": CacheTypeCpp.SELFKONLY, # kv_factor=1, treat K, V separately
|
||||
"num_layers": len(self._paged_cache_order), # correct num layers
|
||||
"num_kv_heads": num_kv_heads_per_layer, # per-layer bytes_per_token
|
||||
"head_dim": 1, # all bytes in num_kv_heads
|
||||
"tokens_per_block": kv_cache_config.tokens_per_block,
|
||||
"max_seq_len": self.info.max_seq_len,
|
||||
"max_batch_size": self.info.max_batch_size,
|
||||
"mapping": Mapping(),
|
||||
# NOTE (lucaslie): this is the only 1-byte dtype currently supported by the
|
||||
# KVCacheManager. Ideally, we would use the typical uint8 dtype for byte-level
|
||||
# abstraction, but this is not supported.
|
||||
"dtype": DataType.FP8, # 1-byte dtype for byte-level abstraction
|
||||
"layer_mask": None,
|
||||
# NOTE (lucaslie): we can always run with False here since when we are estimating, we
|
||||
# are explicitly setting the max_tokens in which case it's okay to use False here since
|
||||
# we don't rely on free_gpu_memory_fraction inside the KVCacheManager. This is similar
|
||||
# to _torch.pyexecutor._util.KVCacheCreator, which explicitly estimates the max_tokens
|
||||
# outside of the KVCacheManager.
|
||||
"is_estimating_kv_cache": False,
|
||||
}
|
||||
return kv_cache_config
|
||||
|
||||
# update args if we are just doing a dummy cache manager
|
||||
if not len(self._paged_cache_order):
|
||||
def _build_kv_cache_kwargs(
|
||||
self,
|
||||
kv_ref: Optional[KVPagedResourceHandler],
|
||||
kv_managed: ResourceHandlerDict,
|
||||
kv_cache_config: KvCacheConfig,
|
||||
) -> Dict:
|
||||
"""Build common kwargs for KVCacheManager or MambaHybridCacheManager.
|
||||
|
||||
Args:
|
||||
kv_ref: Reference KV handler defining head_dim and dtype, or None.
|
||||
kv_managed: Dict of KV resources to be managed.
|
||||
kv_cache_config: Configured KvCacheConfig.
|
||||
|
||||
Returns:
|
||||
Dict of kwargs suitable for both KVCacheManager and MambaHybridCacheManager.
|
||||
"""
|
||||
# create arguments first that differ whether we have managed kv caches or not
|
||||
kv_cache_kwargs = {}
|
||||
if kv_managed:
|
||||
kv_cache_type = CacheTypeCpp.SELFKONLY if kv_ref.kv_factor == 1 else CacheTypeCpp.SELF
|
||||
kv_cache_kwargs.update(
|
||||
{
|
||||
"num_layers": 1,
|
||||
"num_kv_heads": 1,
|
||||
"head_dim": 1,
|
||||
"kv_cache_type": kv_cache_type,
|
||||
"num_layers": len(kv_managed),
|
||||
"num_kv_heads": [h.num_kv_heads for h in kv_managed.values()],
|
||||
"head_dim": kv_ref.head_dim,
|
||||
"dtype": torch_dtype_to_binding(kv_ref.dtype),
|
||||
}
|
||||
)
|
||||
|
||||
if self._state_resource_order:
|
||||
# NOTE: +1 for cuda graph padding
|
||||
kv_cache_kwargs["max_batch_size"] = self.info.max_num_state_slots
|
||||
|
||||
self._kv_cache_manager = MambaHybridCacheManager(
|
||||
# Mamba params for single-layer byte buffer
|
||||
mamba_d_state=1,
|
||||
mamba_d_conv=1, # conv_states will have shape [..., 0] (empty)
|
||||
mamba_num_heads=1,
|
||||
mamba_n_groups=1,
|
||||
mamba_head_dim=cumulative_bytes_per_state[-1], # Total bytes per slot
|
||||
mamba_num_layers=1, # Single layer
|
||||
mamba_layer_mask=None, # Single enabled layer
|
||||
mamba_cache_dtype=torch.uint8, # Byte-level
|
||||
mamba_ssm_cache_dtype=torch.uint8, # Byte-level
|
||||
# KV cache params
|
||||
**kv_cache_kwargs,
|
||||
)
|
||||
else:
|
||||
# No state resources - use pure KVCacheManager
|
||||
self._kv_cache_manager = KVCacheManager(**kv_cache_kwargs)
|
||||
|
||||
# store the tuned kv_cache_config
|
||||
self._kv_cache_config_tuned = kv_cache_config
|
||||
|
||||
# Ensure cache_loc capacity is sufficient for the new KVCacheManager
|
||||
blocks_in_primary_pool = self._kv_cache_manager.blocks_in_primary_pool
|
||||
tokens_per_block = self._kv_cache_manager.tokens_per_block
|
||||
self.info.estimate_cache_loc_capacity(blocks_in_primary_pool)
|
||||
|
||||
# Create paged resource views from per-layer buffers
|
||||
for layer_idx, (name, handler) in enumerate(self._paged_cache_order.items()):
|
||||
view = self._kv_cache_manager.get_buffers(layer_idx, kv_layout="NHD")
|
||||
view = view.view(blocks_in_primary_pool, tokens_per_block, -1).view(handler.dtype)
|
||||
view = view.view(blocks_in_primary_pool, tokens_per_block, *handler.token_shape)
|
||||
|
||||
# Sanity check on contiguity of individual pages
|
||||
view_one_page = view[0]
|
||||
assert view_one_page.is_contiguous(), f"Per-page cache for {name} is not contiguous"
|
||||
|
||||
self._caches[name] = view
|
||||
|
||||
for layer_idx, (name, handler) in enumerate(self._state_resource_order.items()):
|
||||
num_states = len(self._kv_cache_manager.state_indices)
|
||||
# Get the single-layer ssm_states buffer
|
||||
# ssm_states shape: [1, num_states, 1, total_bytes_per_slot, 1]
|
||||
ssm_buffer = self._kv_cache_manager.get_ssm_states(0)
|
||||
# Flatten to [max_batch, total_bytes_per_slot_for_all_layers]
|
||||
ssm_buffer = ssm_buffer.view(num_states, -1)
|
||||
|
||||
offset_start = cumulative_bytes_per_state[layer_idx]
|
||||
offset_end = cumulative_bytes_per_state[layer_idx + 1]
|
||||
|
||||
# Slice at byte offset, reinterpret dtype, reshape
|
||||
view = ssm_buffer[:, offset_start:offset_end]
|
||||
view = view.view(handler.dtype)
|
||||
view = view.view(num_states, *handler.state_shape)
|
||||
|
||||
# Sanity check on contiguity of individual state slots
|
||||
assert view[0].is_contiguous(), f"Per-slot state for {name} cache is not contiguous"
|
||||
|
||||
self._caches[name] = view
|
||||
|
||||
# Patch shutdown to clear cache views before pool release
|
||||
self._kv_cache_manager.shutdown = with_pre_callback(
|
||||
self._kv_cache_manager.shutdown,
|
||||
self._clear_cache_views,
|
||||
kv_cache_kwargs.update(
|
||||
{
|
||||
"kv_cache_type": CacheTypeCpp.SELF,
|
||||
"num_layers": 1,
|
||||
"num_kv_heads": [1],
|
||||
"head_dim": 1,
|
||||
"dtype": DataType.HALF,
|
||||
}
|
||||
)
|
||||
# remaining arguments are the same for both cases
|
||||
kv_cache_kwargs.update(
|
||||
{
|
||||
"kv_cache_config": kv_cache_config,
|
||||
"tokens_per_block": kv_cache_config.tokens_per_block,
|
||||
"max_seq_len": self.info.max_seq_len,
|
||||
"max_batch_size": self.info.max_batch_size,
|
||||
"mapping": Mapping(),
|
||||
"layer_mask": None,
|
||||
# NOTE (lucaslie): we can always run with False here since when we are estimating,
|
||||
# we are explicitly setting the max_tokens in which case it's okay to use False here
|
||||
# since we don't rely on free_gpu_memory_fraction inside the KVCacheManager. This is
|
||||
# similar to _torch.pyexecutor._util.KVCacheCreator, which explicitly estimates the
|
||||
# max_tokens outside of the KVCacheManager.
|
||||
"is_estimating_kv_cache": False,
|
||||
}
|
||||
)
|
||||
|
||||
return kv_cache_kwargs
|
||||
|
||||
def _create_and_assign_state_views(
|
||||
self,
|
||||
kv_cache_kwargs: Dict,
|
||||
ssm_ref: Optional[SSMResourceHandler],
|
||||
ssm_managed: list,
|
||||
conv_ref: Optional[CausalConvResourceHandler],
|
||||
conv_managed: list,
|
||||
) -> Tuple[MambaHybridCacheManager, int]:
|
||||
"""Create MambaHybridCacheManager and assign views for state resources.
|
||||
|
||||
Creates the hybrid cache manager with mamba parameters derived from the reference
|
||||
handlers, then retrieves and assigns buffer views for all managed SSM and Conv resources.
|
||||
|
||||
Args:
|
||||
kv_cache_kwargs: Base kwargs for cache manager (will be extended with mamba params).
|
||||
ssm_ref: Reference SSM handler or None.
|
||||
ssm_managed: List of (name, handler) tuples for SSM resources.
|
||||
conv_ref: Reference Conv handler or None.
|
||||
conv_managed: List of (name, handler) tuples for Conv resources.
|
||||
|
||||
Returns:
|
||||
Tuple of (manager, num_managed_mamba_layers).
|
||||
"""
|
||||
# Derive Mamba parameters from reference handlers
|
||||
mamba_params = self._get_mamba_state_params(
|
||||
ssm_ref, len(ssm_managed), conv_ref, len(conv_managed)
|
||||
)
|
||||
num_managed_mamba_layers = mamba_params["mamba_num_layers"]
|
||||
|
||||
# Create the hybrid cache manager
|
||||
manager = MambaHybridCacheManager(
|
||||
**mamba_params,
|
||||
**kv_cache_kwargs,
|
||||
)
|
||||
|
||||
# Retrieve and assign views for Mamba-managed resources (up to num_managed_mamba_layers)
|
||||
for layer_idx in range(num_managed_mamba_layers):
|
||||
if ssm_managed:
|
||||
ssm_view = manager.get_ssm_states(layer_idx)
|
||||
assert ssm_view.is_contiguous(), f"Non-contiguous state {ssm_managed[layer_idx][0]}"
|
||||
self._caches[ssm_managed[layer_idx][0]] = ssm_view
|
||||
if conv_managed:
|
||||
conv_view = manager.get_conv_states(layer_idx)
|
||||
assert conv_view.is_contiguous(), (
|
||||
f"Non-contiguous state {conv_managed[layer_idx][0]}"
|
||||
)
|
||||
self._caches[conv_managed[layer_idx][0]] = conv_view
|
||||
|
||||
return manager, num_managed_mamba_layers
|
||||
|
||||
def _assign_kv_cache_views(self, kv_managed: Dict[str, KVPagedResourceHandler]) -> None:
|
||||
"""Retrieve and assign buffer views for managed KV paged resources.
|
||||
|
||||
Args:
|
||||
kv_managed: Dict of KV resources managed by the cache manager.
|
||||
"""
|
||||
for idx, (name, h) in enumerate(kv_managed.items()):
|
||||
view = self._kv_cache_manager.get_buffers(idx, kv_layout=h.kv_layout)
|
||||
assert view[0].is_contiguous(), f"Non-contiguous kv cache resource for {name}"
|
||||
self._caches[name] = view
|
||||
|
||||
def _allocate_unmanaged_resources(self) -> None:
|
||||
"""Allocate resources not managed by cache managers.
|
||||
|
||||
Resources that haven't been assigned a tensor (still None) are allocated
|
||||
locally via their handler's allocate() method.
|
||||
"""
|
||||
self._unmanaged_resources.clear()
|
||||
for name, handler in self._resource_lookup.items():
|
||||
if self._caches[name] is None: # Not yet assigned a tensor
|
||||
self._caches[name] = handler.allocate(self.info)
|
||||
self._unmanaged_resources.append(name)
|
||||
|
||||
def _create_kv_cache_manager(self, max_tokens: Optional[int] = None) -> Dict:
|
||||
"""Create KVCacheManager or MambaHybridCacheManager with standard layout.
|
||||
|
||||
For paged resources (KVPagedResourceHandler):
|
||||
- Uses the first KVPagedResourceHandler's head_dim and dtype as reference
|
||||
- Compatible resources (matching head_dim and dtype) go into KVCacheManager
|
||||
- Incompatible resources are allocated locally via handler.allocate()
|
||||
|
||||
For state resources (SSMResourceHandler, CausalConvResourceHandler, StateResourceHandler):
|
||||
- SSMResourceHandler maps to MambaHybridCacheManager's ssm_states buffer
|
||||
- CausalConvResourceHandler maps to MambaHybridCacheManager's conv_states buffer
|
||||
- Generic StateResourceHandler and incompatible typed handlers are allocated locally
|
||||
- When both SSM and Conv handlers exist, uses min(ssm_count, conv_count) layers
|
||||
|
||||
Args:
|
||||
max_tokens: Maximum number of tokens to allocate. If provided, it will use the min value
|
||||
between this value and max_tokens in kv_cache_config.
|
||||
|
||||
Returns:
|
||||
Dictionary with cache statistics including max_tokens. The max_tokens value may differ
|
||||
from the provided ``max_tokens`` arg for two reasons:
|
||||
1. the final number of tokens is synced (min) across ranks
|
||||
2. rounding for getting a multiple of tokens_per_block
|
||||
"""
|
||||
# 1. Identify managed resources
|
||||
kv_ref, kv_managed = self._identify_managed_kv_resources()
|
||||
ssm_ref, ssm_managed, conv_ref, conv_managed = self._identify_managed_state_resources()
|
||||
|
||||
# 2. Prepare configuration
|
||||
kv_cache_config = self._prepare_kv_cache_config(max_tokens, kv_managed)
|
||||
kv_cache_kwargs = self._build_kv_cache_kwargs(kv_ref, kv_managed, kv_cache_config)
|
||||
|
||||
# 3. Create cache manager (delegate to state helper if state resources exist)
|
||||
has_state_resources = ssm_managed or conv_managed
|
||||
if has_state_resources:
|
||||
# NOTE: +1 for cuda graph padding
|
||||
kv_cache_kwargs["max_batch_size"] = self.info.max_num_state_slots
|
||||
self._kv_cache_manager, _ = self._create_and_assign_state_views(
|
||||
kv_cache_kwargs, ssm_ref, ssm_managed, conv_ref, conv_managed
|
||||
)
|
||||
else:
|
||||
# No typed state resources - use pure KVCacheManager
|
||||
self._kv_cache_manager = KVCacheManager(**kv_cache_kwargs)
|
||||
|
||||
# 4. Store tuned config and ensure capacity
|
||||
self._kv_cache_config_tuned = kv_cache_config
|
||||
self.info.estimate_cache_loc_capacity(self._kv_cache_manager.blocks_in_primary_pool)
|
||||
|
||||
# 5. Assign KV views
|
||||
self._assign_kv_cache_views(kv_managed)
|
||||
|
||||
# 6. Allocate remaining unmanaged resources
|
||||
self._allocate_unmanaged_resources()
|
||||
|
||||
# 7. Patch shutdown
|
||||
self._kv_cache_manager.shutdown = with_pre_callback(
|
||||
self._kv_cache_manager.shutdown,
|
||||
self._clear_caches,
|
||||
)
|
||||
|
||||
# 8. Compute final token count and cache statistics
|
||||
max_resource_count = self._kv_cache_manager.get_max_resource_count()
|
||||
max_tokens_final = max_resource_count * self._kv_cache_manager.tokens_per_block
|
||||
|
||||
return max_tokens_final
|
||||
# 9. Collect statistics of different types of resources
|
||||
num_state_total = sum(
|
||||
1 for h in self._resource_lookup.values() if isinstance(h, StateResourceHandler)
|
||||
)
|
||||
num_ssm_total = sum(
|
||||
1 for h in self._resource_lookup.values() if isinstance(h, SSMResourceHandler)
|
||||
)
|
||||
num_conv_total = sum(
|
||||
1 for h in self._resource_lookup.values() if isinstance(h, CausalConvResourceHandler)
|
||||
)
|
||||
num_state_other = num_state_total - num_ssm_total - num_conv_total
|
||||
|
||||
total_managed = len(kv_managed) + len(ssm_managed) + len(conv_managed)
|
||||
|
||||
paged_total = sum(1 for h in self._resource_lookup.values() if h.is_paged)
|
||||
kv_total = sum(
|
||||
1 for h in self._resource_lookup.values() if isinstance(h, KVPagedResourceHandler)
|
||||
)
|
||||
paged_other = paged_total - kv_total
|
||||
|
||||
other_total = len(self._caches) - paged_total - num_state_total
|
||||
|
||||
return {
|
||||
"total": len(self._caches),
|
||||
"total_managed": total_managed,
|
||||
"kv_total": kv_total,
|
||||
"kv_managed": len(kv_managed),
|
||||
"paged_other": paged_other,
|
||||
"ssm_total": num_ssm_total,
|
||||
"ssm_managed": len(ssm_managed),
|
||||
"conv_total": num_conv_total,
|
||||
"conv_managed": len(conv_managed),
|
||||
"state_other": num_state_other,
|
||||
"other": other_total,
|
||||
"max_tokens": max_tokens_final,
|
||||
}
|
||||
|
||||
def initialize_resources(self) -> int:
|
||||
"""Initialize resources - paged/state caches via cache managers, others separately.
|
||||
@ -314,20 +587,12 @@ class CachedSequenceInterface:
|
||||
Returns:
|
||||
The number of caches initialized.
|
||||
"""
|
||||
assert not self._caches and not self._paged_cache_order, "Caches already initialized."
|
||||
assert not self._caches, "Caches already initialized."
|
||||
self.info.to(self.device)
|
||||
|
||||
# Separate resources by type
|
||||
for name, handler in self._resource_lookup.items():
|
||||
if isinstance(handler, PagedResourceHandler):
|
||||
self._paged_cache_order[name] = handler
|
||||
self._caches[name] = None # Will be set by _create_kv_cache_manager
|
||||
elif isinstance(handler, StateResourceHandler):
|
||||
self._state_resource_order[name] = handler
|
||||
self._caches[name] = None # Will be set by _create_kv_cache_manager
|
||||
else:
|
||||
# Unknown handler type - allocate locally (fallback)
|
||||
self._caches[name] = handler.allocate(self.info)
|
||||
# Make sure self._caches has the same order as self._resource_lookup
|
||||
for name in self._resource_lookup.keys():
|
||||
self._caches[name] = None # Will be set by _create_kv_cache_manager
|
||||
|
||||
# Create unified cache manager (handles both paged and state resources)
|
||||
if self.needs_resize() or self._requires_token_estimate():
|
||||
@ -336,104 +601,94 @@ class CachedSequenceInterface:
|
||||
# if we don't need a resize, we will just use the original settings in kv_cache_config
|
||||
# instead of passing in an overwrite here.
|
||||
max_tokens_estimate = None
|
||||
self._create_kv_cache_manager(max_tokens=max_tokens_estimate)
|
||||
cache_stats = self._create_kv_cache_manager(max_tokens=max_tokens_estimate)
|
||||
|
||||
# Log cache statistics summary (format: total/managed)
|
||||
s = cache_stats
|
||||
ad_logger.info(
|
||||
f"Cache stats (total/managed): total={s['total']}/{s['total_managed']}, "
|
||||
f"kv={s['kv_total']}/{s['kv_managed']}, "
|
||||
f"paged_other={s['paged_other']}, "
|
||||
f"ssm={s['ssm_total']}/{s['ssm_managed']}, "
|
||||
f"conv={s['conv_total']}/{s['conv_managed']}, "
|
||||
f"state_other={s['state_other']}, "
|
||||
f"other={s['other']}, "
|
||||
f"max_tokens={s['max_tokens']}"
|
||||
)
|
||||
|
||||
return len(self._caches)
|
||||
|
||||
def is_paged(self) -> bool:
|
||||
"""Return True if all resources are paged and part of the KVCacheManager."""
|
||||
return set(self._paged_cache_order.keys()) == set(self._resource_lookup.keys())
|
||||
|
||||
def _requires_token_estimate(self) -> bool:
|
||||
"""Check if our kv_cache_config requires."""
|
||||
return (
|
||||
self._kv_cache_config_original.free_gpu_memory_fraction in [None, 0.0]
|
||||
and self._kv_cache_config_original.max_tokens is None
|
||||
)
|
||||
"""Check if our kv_cache_config requires max_tokens to be estimated."""
|
||||
needs_max_tokens = self._kv_cache_config_original.free_gpu_memory_fraction in [None, 0.0]
|
||||
needs_max_tokens |= not any(handler.is_paged for handler in self._resource_lookup.values())
|
||||
return needs_max_tokens and self._kv_cache_config_original.max_tokens is None
|
||||
|
||||
def needs_resize(self) -> bool:
|
||||
"""Check if we need a resize or not."""
|
||||
has_paged = bool(self._paged_cache_order)
|
||||
has_paged = any(handler.is_paged for handler in self._resource_lookup.values())
|
||||
return has_paged and self._kv_cache_config_original.free_gpu_memory_fraction not in [
|
||||
None,
|
||||
0.0,
|
||||
]
|
||||
|
||||
def resize_kv_cache_manager(self, mem_exclude: int = 0) -> None:
|
||||
"""Shutdown existing KVCacheManager and create new one with optimal capacity.
|
||||
"""Shutdown existing caches and recreate with optimal capacity for paged resources.
|
||||
|
||||
Args:
|
||||
mem_exclude: Extra memory to exclude from the calculation of optimal capacity.
|
||||
This is in bytes and typically the memory reserved for the forward pass.
|
||||
|
||||
This implements the two-phase approach: after running a forward pass during estimation
|
||||
to allocate intermediate memory, call this method to recreate the KVCacheManager.
|
||||
The new manager will compute optimal capacity based on current free GPU memory
|
||||
via calculate_max_num_blocks.
|
||||
to allocate intermediate memory, call this method to recreate the cache manager.
|
||||
The new manager will compute optimal capacity based on current free GPU memory.
|
||||
"""
|
||||
if not self.needs_resize():
|
||||
return
|
||||
|
||||
# get per-token cache size for resizable resources
|
||||
paged_cache_bytes_per_token = self._kv_cache_manager.get_cache_bytes_per_token()
|
||||
|
||||
# get total cache size of state resources that cannot be resized
|
||||
# NOTE: this does NOT include resources handled OUTSIDE of the KVCacheManager or
|
||||
# MambaHybridCacheManager. Those will persist and will be accounted for via free_mem even
# after the initial kv_cache_manager is shut down.
|
||||
state_cache_bytes_total = sum(
|
||||
cache.numel() * cache.element_size()
|
||||
for name, cache in self._caches.items()
|
||||
if name in self._state_resource_order
|
||||
# Calculate bytes-per-token for paged (resizable) resources
|
||||
paged_bytes_per_token = sum(
|
||||
h.bytes_per_token for h in self._resource_lookup.values() if h.is_paged
|
||||
)
|
||||
|
||||
# get unmanaged cache size
|
||||
unmanaged_cache_bytes_total = sum(
|
||||
# Calculate total bytes for non-paged (non-resizable) resources
|
||||
non_paged_bytes_total = sum(
|
||||
cache.numel() * cache.element_size()
|
||||
for name, cache in self._caches.items()
|
||||
if name not in self._paged_cache_order and name not in self._state_resource_order
|
||||
if not self._resource_lookup[name].is_paged
|
||||
)
|
||||
|
||||
# Shutdown existing KVCacheManager to free memory
|
||||
# Shutdown clears ALL cache views (paged and non-paged)
|
||||
self._kv_cache_manager.shutdown()
|
||||
|
||||
# Get current free GPU memory (roughly includes model weights + non-managed resources)
|
||||
# Get current free GPU memory after shutdown
|
||||
_, free_mem, *_ = get_mem_info(empty_cache=True)
|
||||
|
||||
# Compute available memory for the KVCacheManager
|
||||
# NOTE: free_mem was obtained AFTER shutdown of initial KVCacheManager - hence it accounts
|
||||
# for unmanaged resources but it does NOT account for state resources since those were
|
||||
# freed as part of the shutdown.
|
||||
# Compute available memory for paged caches
|
||||
# Reserve space for non-paged caches and mem_exclude, then apply free_gpu_memory_fraction
|
||||
free_gpu_memory_fraction = self._kv_cache_config_original.free_gpu_memory_fraction
|
||||
mem_for_paged_optimal = (
|
||||
free_mem - state_cache_bytes_total - mem_exclude
|
||||
free_mem - non_paged_bytes_total - mem_exclude
|
||||
) * free_gpu_memory_fraction
|
||||
# Check how many tokens we can fit into the paged cache
|
||||
max_tokens_optimal = int(mem_for_paged_optimal // paged_cache_bytes_per_token)
|
||||
max_tokens_optimal = int(mem_for_paged_optimal // paged_bytes_per_token)
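# Worked example of the sizing above (illustrative numbers only): with 40 GiB free
# after shutdown, 2 GiB of non-paged caches, 1 GiB excluded for the forward pass,
# free_gpu_memory_fraction=0.9, and 160 KiB of paged cache per token, this yields
# (40 - 2 - 1) GiB * 0.9 / 160 KiB ~= 218k tokens for the paged cache.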
|
||||
|
||||
# Create new KVCacheManager with final capacity
|
||||
max_tokens_final = self._create_kv_cache_manager(max_tokens=max_tokens_optimal)
|
||||
# Create new cache manager with optimal capacity
|
||||
cache_stats = self._create_kv_cache_manager(max_tokens=max_tokens_optimal)
|
||||
max_tokens_final = cache_stats["max_tokens"]
|
||||
|
||||
# Log resulting memory information
|
||||
mem_info = [
|
||||
f"free_mem={bytes_to(free_mem, unit='GB'):.2f}GB",
|
||||
f"free_gpu_memory_fraction={free_gpu_memory_fraction}",
|
||||
f"mem_exclude={bytes_to(mem_exclude, unit='GB'):.2f}GB",
|
||||
f"mem_exclude_for_state={bytes_to(state_cache_bytes_total, unit='GB'):.2f}GB",
|
||||
f"mem_for_paged_optimal={bytes_to(mem_for_paged_optimal, unit='GB'):.2f}GB",
|
||||
]
|
||||
total_cache_bytes = (
|
||||
mem_for_paged_optimal + state_cache_bytes_total + unmanaged_cache_bytes_total
|
||||
total_cache_bytes = mem_for_paged_optimal + non_paged_bytes_total
|
||||
ad_logger.info(
|
||||
f"Resize mem info: free_mem={bytes_to(free_mem, unit='GB'):.2f}GB, "
|
||||
f"free_gpu_memory_fraction={free_gpu_memory_fraction}, "
|
||||
f"mem_exclude={bytes_to(mem_exclude, unit='GB'):.2f}GB"
|
||||
)
|
||||
ad_logger.info(
|
||||
f"Final cache mem: max_tokens={max_tokens_final}, "
|
||||
f"paged={bytes_to(mem_for_paged_optimal, unit='GB'):.2f}GB, "
|
||||
f"non_paged={bytes_to(non_paged_bytes_total, unit='GB'):.2f}GB, "
|
||||
f"total={bytes_to(total_cache_bytes, unit='GB'):.2f}GB"
|
||||
)
|
||||
mem_cache_info = [
|
||||
f"Max Tokens={max_tokens_final}",
|
||||
f"Paged={bytes_to(mem_for_paged_optimal, unit='GB'):.2f}GB",
|
||||
f"State={bytes_to(state_cache_bytes_total, unit='GB'):.2f}GB",
|
||||
f"Unmanaged={bytes_to(unmanaged_cache_bytes_total, unit='GB'):.2f}GB",
|
||||
f"Total={bytes_to(total_cache_bytes, unit='GB'):.2f}GB",
|
||||
]
|
||||
ad_logger.info(f"Mem info for resize: {' | '.join(mem_info)}")
|
||||
ad_logger.info(f"Final Cache Mem: {' | '.join(mem_cache_info)}")
|
||||
|
||||
@property
|
||||
def kv_cache_manager(self) -> Optional[KVCacheManager]:
|
||||
@ -454,20 +709,17 @@ class CachedSequenceInterface:
|
||||
"""Return the original KVCacheConfig as passed in."""
|
||||
return self._kv_cache_config_original
|
||||
|
||||
def _clear_cache_views(self) -> None:
|
||||
"""Set paged and state cache views to None before pool release."""
|
||||
self._kv_cache_config_tuned = None
|
||||
for name in self._paged_cache_order:
|
||||
self._caches[name] = None
|
||||
for name in self._state_resource_order:
|
||||
self._caches[name] = None
|
||||
def _clear_caches(self) -> None:
|
||||
"""Clear all caches and views before pool release."""
|
||||
for k in self._caches:
|
||||
self._caches[k] = None
|
||||
self._unmanaged_resources.clear()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown and release all resources."""
|
||||
if self._kv_cache_manager is not None:
|
||||
self._kv_cache_manager.shutdown()
|
||||
self._kv_cache_config_tuned = None
|
||||
self._caches.clear()
|
||||
self._clear_caches()
|
||||
|
||||
|
||||
GetInferenceModel = Callable[[CachedSequenceInterface], nn.Module]
|
||||
|
||||
@ -655,6 +655,7 @@ class BaseTransform(ABC):
|
||||
abs(diff.resv) >= mem_change_threshold
|
||||
or abs(diff.alloc) >= mem_change_threshold
|
||||
or abs(diff.frag) >= mem_change_threshold
|
||||
or abs(diff.free) >= mem_change_threshold
|
||||
)
|
||||
|
||||
def _fmt_val_with_delta(val: float, delta: float, color: str) -> str:
|
||||
|
||||
@ -41,7 +41,7 @@ from ..interface (
TransformInfo,
TransformRegistry,
)
from .kvcache import InsertCachedAttention
from .kvcache import _InsertCachedOperator


@torch.library.custom_op("auto_deploy::residual_add_for_capture", mutates_args=())
@ -251,5 +251,5 @@ class CachedResidualAdd(AttentionDescriptor):


@TransformRegistry.register("insert_cached_residual_add")
class InsertCachedResidualAdd(InsertCachedAttention):
class InsertCachedResidualAdd(_InsertCachedOperator):
"""A transform to handle residual add cache operations."""
@ -51,10 +51,8 @@ class InsertCachedAttentionConfig(TransformConfig):
backend: Optional[str] = Field(default=None, description="The attention backend to use.")


@TransformRegistry.register("insert_cached_attention")
class InsertCachedAttention(BaseTransform):
"""
A transform to insert cached attention into the graph module."""
class _InsertCachedOperator(BaseTransform):
"""A generic base transform to insert cached operators into the graph module."""

config: InsertCachedAttentionConfig

@ -236,15 +234,22 @@ class InsertCachedAttention(BaseTransform):
return gm, info


@TransformRegistry.register("insert_cached_attention")
class InsertCachedAttention(_InsertCachedOperator):
"""A transform to insert cached attention into the graph module."""

def _apply(self, *args, **kwargs):
if self.config.backend == "triton":
self._log_warning(
"'triton' backend only supports GREEDY sampling (top_k=1). "
"Please set top_k=1 in the sampling parameters to ensure cohesive output."
)
return super()._apply(*args, **kwargs)


@TransformRegistry.register("insert_cached_mla_attention")
class InsertCachedMLAAttention(InsertCachedAttention):
"""
A transform to insert cached MLA attention into the graph module.

This class is identical to InsertCachedAttention and inherits all its behavior.
"""

pass
class InsertCachedMLAAttention(_InsertCachedOperator):
"""A transform to insert cached MLA attention into the graph module."""
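The hunks above fold `InsertCachedAttention` and its siblings into a shared `_InsertCachedOperator` base, with each concrete transform registered under its own key. A schematic sketch of that register-a-subclass pattern, using a toy registry rather than the actual `TransformRegistry`:

```python
from typing import Callable, Dict, Type


class ToyRegistry:
    _transforms: Dict[str, Type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type], Type]:
        def _wrap(transform_cls: Type) -> Type:
            cls._transforms[name] = transform_cls  # store class under its key
            return transform_cls

        return _wrap

    @classmethod
    def get(cls, name: str) -> Type:
        return cls._transforms[name]


class _InsertCachedOperatorSketch:
    """Shared plumbing; subclasses only differ in which cached op they insert."""

    def apply(self) -> str:
        return f"applied {type(self).__name__}"


@ToyRegistry.register("insert_cached_attention")
class InsertCachedAttentionSketch(_InsertCachedOperatorSketch):
    pass


@ToyRegistry.register("insert_cached_ssm_attention")
class SSMCacheTransformSketch(_InsertCachedOperatorSketch):
    pass


print(ToyRegistry.get("insert_cached_attention")().apply())
```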
@TransformRegistry.register("resize_kv_cache")
|
||||
@ -316,7 +321,6 @@ class InitializeCache(BaseTransform):
|
||||
# Initialize with estimation mode
|
||||
# This allows resize_kv_cache to recreate with correct capacity after measuring memory
|
||||
num_caches = cm.initialize_resources()
|
||||
self._log_info(f"Initialized {num_caches} caches for cached attention")
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False, num_matches=num_caches, is_clean=True, has_valid_shapes=True
|
||||
|
||||
@ -16,7 +16,7 @@ from ...export.library.unified_attn import HF_ATTN_KWARGS_MAPPING
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
from .kvcache import InsertCachedAttention
from .kvcache import _InsertCachedOperator


def fake_profiler_mha(
@ -202,7 +202,7 @@ def forward_with_prepare_metadata(mod: nn.Module, **cm_kwargs):
# TODO: how running different kv cache transforms look like? This one below wouldn't work if we
# had multiple types attention to replace...
@TransformRegistry.register("transformers_replace_cached_attn")
class HFReplaceCachedAttn(InsertCachedAttention):
class HFReplaceCachedAttn(_InsertCachedOperator):
"""Replace cached attention for the factory model, update inputs and outputs, and patch the gm forward."""

def _add_or_retrieve_input(
@ -1,20 +1,20 @@
"""A set of transforms to handle SSM cache transforms."""

from ..interface import TransformRegistry
from .kvcache import InsertCachedAttention
from .kvcache import _InsertCachedOperator


# TODO: think about separating valid attention backends per transform better in the future
@TransformRegistry.register("insert_cached_ssm_attention")
class SSMCacheTransform(InsertCachedAttention):
class SSMCacheTransform(_InsertCachedOperator):
"""A transform to handle SSM cache operations."""


@TransformRegistry.register("insert_cached_causal_conv")
class InitializeCausalConvCache(InsertCachedAttention):
class InitializeCausalConvCache(_InsertCachedOperator):
"""A transform to handle causal conv cache operations."""


@TransformRegistry.register("insert_cached_delta_rule")
class InsertCachedDeltaRule(InsertCachedAttention):
class InsertCachedDeltaRule(_InsertCachedOperator):
"""A transform to handle delta rule cache operations."""
@ -58,13 +58,32 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B"
MODEL_PATH = hf_id_to_local_model_dir(MODEL_NAME)

def get_default_kwargs(self, enable_chunked_prefill=False):
# Configuration presets for different attention backends
ATTN_BACKEND_CONFIGS = {
"flashinfer": {
"max_batch_size": 512,
"max_seq_len": 8192,
"compile_backend": "torch-cudagraph",
},
"torch": {
"max_batch_size": 128,
"max_seq_len": 2048,
"compile_backend": "torch-simple",
},
}

def get_default_kwargs(self,
enable_chunked_prefill=False,
attn_backend="flashinfer"):
backend_cfg = self.ATTN_BACKEND_CONFIGS[attn_backend]

config = {
"skip_tokenizer_init": False,
"trust_remote_code": True,
"max_batch_size": 512,
"attn_backend": attn_backend,
"max_batch_size": backend_cfg["max_batch_size"],
# 131072 is the max seq len for the model
"max_seq_len": 8192,
"max_seq_len": backend_cfg["max_seq_len"],
# max num tokens is derived in the build_config, which is not used by AutoDeploy llmargs.
# Set it explicitly here to 8192 which is the default in build_config.
"max_num_tokens": 8192,
@ -75,7 +94,7 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
"transforms": {
"compile_model": {
"backend":
"torch-cudagraph",
backend_cfg["compile_backend"],
"cuda_graph_batch_sizes":
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
},
@ -83,8 +102,8 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
}
if enable_chunked_prefill:
config["enable_chunked_prefill"] = True
config[
"max_num_tokens"] = 512 # NOTE: must be > max(tokens_per_block, max_batch_size)
# NOTE: must be > max(tokens_per_block, max_batch_size)
config["max_num_tokens"] = 512
return config

def get_default_sampling_params(self):
@ -98,8 +117,9 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.parametrize("world_size", [1, 2, 4])
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
def test_auto_dtype(self, world_size, enable_chunked_prefill):
kwargs = self.get_default_kwargs(enable_chunked_prefill)
@pytest.mark.parametrize("attn_backend", ["flashinfer", "torch"])
def test_auto_dtype(self, world_size, enable_chunked_prefill, attn_backend):
kwargs = self.get_default_kwargs(enable_chunked_prefill, attn_backend)
sampling_params = self.get_default_sampling_params()
with AutoDeployLLM(model=self.MODEL_PATH,
tokenizer=self.MODEL_PATH,
@ -107,8 +127,9 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
**kwargs) as llm:
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm, sampling_params=sampling_params)
if attn_backend != "torch":
task = MMLU(self.MODEL_NAME)
task.evaluate(llm, sampling_params=sampling_params)

@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.skip_less_device(2)
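The test now pulls per-backend limits from `ATTN_BACKEND_CONFIGS` before assembling the LLM kwargs. A minimal sketch of the same preset-merge idea; the preset values are copied from the diff, while the merge helper itself is illustrative:

```python
ATTN_BACKEND_CONFIGS = {
    "flashinfer": {"max_batch_size": 512, "max_seq_len": 8192, "compile_backend": "torch-cudagraph"},
    "torch": {"max_batch_size": 128, "max_seq_len": 2048, "compile_backend": "torch-simple"},
}


def build_kwargs(attn_backend: str = "flashinfer", enable_chunked_prefill: bool = False) -> dict:
    backend_cfg = ATTN_BACKEND_CONFIGS[attn_backend]
    config = {
        "attn_backend": attn_backend,
        "max_batch_size": backend_cfg["max_batch_size"],
        "max_seq_len": backend_cfg["max_seq_len"],
        "max_num_tokens": 8192,
    }
    if enable_chunked_prefill:
        config["enable_chunked_prefill"] = True
        config["max_num_tokens"] = 512  # must exceed max(tokens_per_block, max_batch_size)
    return config


print(build_kwargs("torch", enable_chunked_prefill=True))
```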
@ -176,6 +176,7 @@ l0_b200:
stage: pre_merge
backend: autodeploy
tests:
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-1]
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[flashinfer-False-1]
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[torch-True-1]
- accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1]
- unittest/_torch/auto_deploy/unit/singlegpu

@ -221,7 +221,7 @@ l0_dgx_b200:
orchestrator: mpi
tests:
- unittest/_torch/auto_deploy/unit/multigpu
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-4]
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[flashinfer-False-4]
- accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[4]
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_bf16
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[4]

@ -333,7 +333,7 @@ l0_dgx_h100:
orchestrator: mpi
tests:
- unittest/_torch/auto_deploy/unit/multigpu
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-4]
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[flashinfer-False-4]
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_bf16
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[4]
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[8]

@ -432,8 +432,8 @@ l0_h100:
orchestrator: mpi
tests:
- unittest/_torch/auto_deploy/unit/singlegpu
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-1]
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[True-1]
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[flashinfer-False-1]
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[flashinfer-True-1]
- accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[triton_ssm-False]
- accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[flashinfer_ssm-False]
- accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[triton_ssm-True]
@ -1,83 +1,10 @@
|
||||
import pytest
|
||||
import torch
|
||||
from _custom_op_utils import torch_rope_reference
|
||||
from torch_attention_reference import TorchAttentionReference
|
||||
|
||||
import tensorrt_llm._torch.auto_deploy # noqa: F401
|
||||
|
||||
|
||||
def test_attention_op():
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.float16
|
||||
BATCH_SIZE = 2
|
||||
N_HEADS = 8
|
||||
D_HEAD = 32
|
||||
MAX_SEQ_LEN = 128
|
||||
|
||||
# Q,K,V are computed using GEMM.
|
||||
qkv = torch.randn(BATCH_SIZE, 3, N_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE)
|
||||
k_cache = torch.zeros((BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=DEVICE)
|
||||
v_cache = torch.zeros((BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=DEVICE)
|
||||
input_positions = torch.zeros(BATCH_SIZE, device=DEVICE, dtype=torch.int) + 1
|
||||
|
||||
q, k, v = (x.contiguous() for x in torch.split(qkv, 1, dim=1))
|
||||
|
||||
output = torch.ops.auto_deploy.triton_attention_fused_mha_with_cache(
|
||||
q, k, v, input_positions, k_cache, v_cache, None
|
||||
)
|
||||
# Use torch backend as clean reference
|
||||
ref = TorchAttentionReference.basic_mha_with_cache(q, k, v, k_cache, v_cache, input_positions)
|
||||
assert torch.allclose(
|
||||
ref.cpu().to(torch.float32),
|
||||
output.cpu().to(torch.float32),
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_len", [1, 8])
|
||||
@pytest.mark.parametrize("group_size", [1, 4])
|
||||
@pytest.mark.parametrize("n_heads", [8])
|
||||
@pytest.mark.parametrize("dtype", ["float16", "float32"])
|
||||
@pytest.mark.parametrize("device", ["cuda"])
|
||||
def test_gqa_op(device, dtype, n_heads, group_size, seq_len):
|
||||
BATCH_SIZE = 2
|
||||
D_HEAD = 16
|
||||
CACHE_SEQ_LEN = 8
|
||||
|
||||
dtype = getattr(torch, dtype)
|
||||
n_kv_heads = n_heads // group_size
|
||||
|
||||
if seq_len == 1:
|
||||
offset = seq_len // 2
|
||||
input_positions = torch.zeros(BATCH_SIZE, device=device, dtype=torch.int) + offset
|
||||
else:
|
||||
input_positions = torch.zeros(BATCH_SIZE, device=device, dtype=torch.int)
|
||||
|
||||
q = torch.randn(BATCH_SIZE, seq_len, n_heads, D_HEAD, dtype=dtype, device=device)
|
||||
k = torch.randn(BATCH_SIZE, seq_len, n_kv_heads, D_HEAD, dtype=dtype, device=device)
|
||||
v = torch.randn(BATCH_SIZE, seq_len, n_kv_heads, D_HEAD, dtype=dtype, device=device)
|
||||
|
||||
# setup kv-cache
|
||||
k_cache = torch.randn(BATCH_SIZE, CACHE_SEQ_LEN, n_kv_heads, D_HEAD, dtype=dtype, device=device)
|
||||
v_cache = torch.randn(BATCH_SIZE, CACHE_SEQ_LEN, n_kv_heads, D_HEAD, dtype=dtype, device=device)
|
||||
|
||||
# run custom op
|
||||
output = torch.ops.auto_deploy.triton_attention_fused_mha_with_cache(
|
||||
q, k, v, input_positions, k_cache, v_cache, None
|
||||
)
|
||||
|
||||
# Use torch backend as clean reference
|
||||
ref = TorchAttentionReference.basic_mha_with_cache(q, k, v, k_cache, v_cache, input_positions)
|
||||
|
||||
assert torch.allclose(
|
||||
ref.cpu().to(torch.float32),
|
||||
output.cpu().to(torch.float32),
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_generate_ratio", [0.0, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("max_seq_len", [0, 1, 16])
|
||||
@pytest.mark.parametrize("group_size", [1, 4])
|
||||
@ -159,331 +86,3 @@ def test_flat_gqa_op(
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_generate_ratio", [0.0, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("max_seq_len", [0, 1, 16])
|
||||
@pytest.mark.parametrize("group_size", [1, 4])
|
||||
@pytest.mark.parametrize("n_heads", [8])
|
||||
@pytest.mark.parametrize("batch_size", [1, 16])
|
||||
@pytest.mark.parametrize("dtype", ["float16", "float32"])
|
||||
@pytest.mark.parametrize("device", ["cuda"])
|
||||
def test_flat_gqa_op_with_rope(
|
||||
device, dtype, batch_size, n_heads, group_size, max_seq_len, num_generate_ratio
|
||||
):
|
||||
n_heads = n_heads
|
||||
n_kv_heads = n_heads // group_size
|
||||
D_HEAD = 16
|
||||
dtype = getattr(torch, dtype)
|
||||
int_kwargs = {"device": device, "dtype": torch.int32}
|
||||
dtype_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
# setup caches with 2*batch_size, 2*max_seq_len since we also randomize input_pos
|
||||
cache_max_seq_len = 2 * (max_seq_len + 1)
|
||||
cache_max_batch_size = 2 * batch_size
|
||||
cache_size = (cache_max_batch_size, cache_max_seq_len, n_kv_heads, D_HEAD)
|
||||
cache_loc = torch.randperm(cache_max_batch_size, **int_kwargs)[:batch_size]
|
||||
|
||||
k_cache = torch.randn(cache_size, **dtype_kwargs)
|
||||
v_cache = torch.randn(cache_size, **dtype_kwargs)
|
||||
|
||||
# randomize num_context vs num_generate; NOTE: we can use context kernel for generate as well
|
||||
num_generate = torch.tensor(num_generate_ratio * batch_size, **int_kwargs)
|
||||
num_context = batch_size - num_generate
|
||||
|
||||
# construct random input_positions
|
||||
input_positions = torch.randint(0, max_seq_len + 1, (batch_size,), **int_kwargs)
|
||||
|
||||
# construct seq_len, seq_start;
|
||||
seq_len = torch.cat(
|
||||
[
|
||||
torch.randint(0, max_seq_len + 1, (num_context,), **int_kwargs), # context
|
||||
torch.zeros(num_generate, **int_kwargs) + (max_seq_len > 0), # generate
|
||||
]
|
||||
)
|
||||
seq_start = seq_len.cumsum(0) - seq_len
|
||||
|
||||
# get fake input
|
||||
q = torch.randn(1, seq_len.sum(), n_heads * D_HEAD, **dtype_kwargs)
|
||||
k = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs)
|
||||
v = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs)
|
||||
|
||||
# rope can modify the original tensor value
|
||||
q_o = q.clone()
|
||||
k_o = k.clone()
|
||||
|
||||
freqs_cis = torch.rand([cache_max_seq_len, D_HEAD // 2, 2], device=device, dtype=torch.float32)
|
||||
|
||||
# run op
|
||||
source = 1
|
||||
if source == 1:
|
||||
# call rope fusion kernels
|
||||
output = torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache_rope_fusion(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
input_positions,
|
||||
cache_loc,
|
||||
seq_len,
|
||||
seq_start,
|
||||
k_cache,
|
||||
v_cache,
|
||||
freqs_cis,
|
||||
)
|
||||
else:
|
||||
# call stand-alone rope embedding kernel
|
||||
output = torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
input_positions,
|
||||
cache_loc,
|
||||
seq_len,
|
||||
seq_start,
|
||||
k_cache,
|
||||
v_cache,
|
||||
freqs_cis,
|
||||
)
|
||||
|
||||
# prep batched tensors for comparison
|
||||
q_b = torch.zeros(batch_size, n_heads, max_seq_len, D_HEAD, **dtype_kwargs)
|
||||
k_cache_b = k_cache[cache_loc].transpose(1, 2)
|
||||
v_cache_b = v_cache[cache_loc].transpose(1, 2)
|
||||
|
||||
def _store(t_batched, t_flat):
|
||||
# batched layout: [n,s,d]; flat layout: [s,n*d]
|
||||
n_h, _, d_h = t_batched.shape
|
||||
t_batched[:] = t_flat.view(-1, n_h, d_h).transpose(0, 1)
|
||||
|
||||
def _store_rope(t_batched, t_flat, input_pos):
|
||||
# batched layout: [n,s,d];
|
||||
# flat layout: [s,n*d], and in interleaved order
|
||||
# need to reorder to normal for torch_rope_reference and then reorder back
|
||||
n_h, _, d_h = t_batched.shape
|
||||
t_i = t_flat.view(-1, n_h, d_h).unsqueeze(0)
|
||||
t = t_i.unflatten(-1, (2, D_HEAD // 2)).transpose(-1, -2).flatten(-2).contiguous()
|
||||
t_rope = torch_rope_reference(t, freqs_cis, input_pos)
|
||||
t_rope = t_rope.unflatten(-1, (D_HEAD // 2, 2)).transpose(-1, -2).flatten(-2).contiguous()
|
||||
t_batched[:] = t_rope[0].transpose(0, 1)
|
||||
|
||||
for i_b, (i_pos, s_start, s_len) in enumerate(zip(input_positions, seq_start, seq_len)):
|
||||
# fill roped q, k in a batched manner
|
||||
_store_rope(q_b[i_b, :, :s_len], q_o[0, s_start : s_start + s_len], input_positions[i_b])
|
||||
_store_rope(
|
||||
k_cache_b[i_b, :, i_pos : i_pos + s_len],
|
||||
k_o[0, s_start : s_start + s_len],
|
||||
input_positions[i_b],
|
||||
)
|
||||
# fill v in a batched manner
|
||||
_store(v_cache_b[i_b, :, i_pos : i_pos + s_len], v[0, s_start : s_start + s_len])
|
||||
|
||||
k_cache_b = torch.repeat_interleave(k_cache_b, group_size, dim=1) # [b,n,s,d]
|
||||
v_cache_b = torch.repeat_interleave(v_cache_b, group_size, dim=1) # [b,n,s,d]
|
||||
|
||||
# run comparison
|
||||
refs = []
|
||||
for i_b, (i_pos, s_start, s_len) in enumerate(zip(input_positions, seq_start, seq_len)):
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.ones(s_len, i_pos, device=device, dtype=torch.bool),
|
||||
torch.tril(torch.ones(s_len, s_len, device=device, dtype=torch.bool)),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
ref_i = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_b[i_b, :, :s_len],
|
||||
k_cache_b[i_b, :, : i_pos + s_len],
|
||||
v_cache_b[i_b, :, : i_pos + s_len],
|
||||
attn_mask=mask,
|
||||
) # [n,s,d]
|
||||
ref_i = ref_i.transpose(0, 1).contiguous().view(s_len, n_heads * D_HEAD) # [s,n*d]
|
||||
refs.append(ref_i)
|
||||
|
||||
# flatten output for comparison
|
||||
ref_flat = torch.cat(refs, dim=0)[None] # [1,s_total,n*d]
|
||||
|
||||
assert torch.allclose(
|
||||
ref_flat.cpu().to(torch.float32),
|
||||
output.cpu().to(torch.float32),
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_generate_ratio", [0.0, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("max_seq_len", [0, 1, 16])
|
||||
@pytest.mark.parametrize("group_size", [1, 4])
|
||||
@pytest.mark.parametrize("n_heads", [8])
|
||||
@pytest.mark.parametrize("batch_size", [1, 16])
|
||||
@pytest.mark.parametrize("dtype", ["float16", "float32"])
|
||||
@pytest.mark.parametrize("device", ["cuda"])
|
||||
def test_paged_gqa_op(
|
||||
device, dtype, batch_size, n_heads, group_size, max_seq_len, num_generate_ratio
|
||||
):
|
||||
n_heads = n_heads
|
||||
n_kv_heads = n_heads // group_size
|
||||
D_HEAD = 16
|
||||
dtype = getattr(torch, dtype)
|
||||
int_kwargs = {"device": device, "dtype": torch.int32}
|
||||
dtype_kwargs = {"device": device, "dtype": dtype}
|
||||
PAGE_SIZE = 4
|
||||
|
||||
# setup caches with 2*batch_size, 2*max_seq_len since we also randomize input_pos
|
||||
cache_max_seq_len = 2 * (max_seq_len + 1)
|
||||
cache_max_batch_size = 2 * batch_size
|
||||
cache_max_pages = (cache_max_batch_size * cache_max_seq_len + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
batch_max_pages = (cache_max_seq_len + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
cache_size = (cache_max_pages, PAGE_SIZE, n_kv_heads, D_HEAD)
|
||||
cache_loc = torch.randperm(cache_max_batch_size, **int_kwargs)[:batch_size]
|
||||
|
||||
k_cache = torch.zeros(cache_size, **dtype_kwargs)
|
||||
v_cache = torch.randn(cache_size, **dtype_kwargs)
|
||||
|
||||
# randomize num_context vs num_generate; NOTE: we can use context kernel for generate as well
|
||||
num_generate = torch.tensor(num_generate_ratio * batch_size, **int_kwargs)
|
||||
num_context = batch_size - num_generate
|
||||
|
||||
# construct seq_len, seq_start;
|
||||
# Context seq_len = 0 can result in wrong view and disorder infos like seq_start,
|
||||
# only check seq_len > 0.
|
||||
# i.e. num_context = 1, num_generate = 1 and seq_len = [0, 1],
|
||||
# but the op might mistake the batch as batch_size = 1 and use the context batch infos.
|
||||
seq_len = torch.cat(
|
||||
[
|
||||
torch.randint(
|
||||
1 if max_seq_len > 0 else 0,
|
||||
max_seq_len + 1,
|
||||
(num_context,),
|
||||
**int_kwargs,
|
||||
), # context
|
||||
torch.zeros(num_generate, **int_kwargs) + (max_seq_len > 0), # generate
|
||||
]
|
||||
)
|
||||
|
||||
# construct random input_positions(cache_len)
|
||||
input_positions = torch.cat(
|
||||
[
|
||||
torch.zeros(num_context, **int_kwargs), # context
|
||||
torch.randint(0, max_seq_len + 1, (num_generate,), **int_kwargs), # generate
|
||||
]
|
||||
)
|
||||
|
||||
seq_start = (seq_len.cumsum(0) - seq_len).to(torch.int32)
|
||||
|
||||
# allocate pages for kv cache
|
||||
# pages of each batch is continuous
|
||||
PAGE_TABLE = [[0] * batch_max_pages] * cache_max_batch_size
|
||||
cnt = 0
|
||||
for b in range(batch_size):
|
||||
# allocate pages for history kv cache and new coming kv
|
||||
length = input_positions[b] + seq_len[b]
|
||||
allocated_pages = (length + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
table = []
|
||||
for p in range(batch_max_pages):
|
||||
if p < allocated_pages:
|
||||
table.append(cnt)
|
||||
cnt = cnt + 1
|
||||
else:
|
||||
table.append(0)
|
||||
PAGE_TABLE[cache_loc[b]] = table
|
||||
# prepare value for kv cache of decode batch
|
||||
cache_pages = input_positions[b] // PAGE_SIZE
|
||||
cache_page_off = input_positions[b] % PAGE_SIZE
|
||||
k_cache[table[0] : table[cache_pages]] = torch.randn(
|
||||
cache_pages, PAGE_SIZE, n_kv_heads, D_HEAD, **dtype_kwargs
|
||||
)
|
||||
v_cache[table[0] : table[cache_pages]] = torch.randn(
|
||||
cache_pages, PAGE_SIZE, n_kv_heads, D_HEAD, **dtype_kwargs
|
||||
)
|
||||
k_cache[table[cache_pages], 0:cache_page_off] = torch.randn(
|
||||
cache_page_off, n_kv_heads, D_HEAD, **dtype_kwargs
|
||||
)
|
||||
v_cache[table[cache_pages], 0:cache_page_off] = torch.randn(
|
||||
cache_page_off, n_kv_heads, D_HEAD, **dtype_kwargs
|
||||
)
|
||||
|
||||
page_table = torch.tensor(PAGE_TABLE, **int_kwargs)
|
||||
|
||||
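The allocation loop above hands out contiguous page indices per sequence using ceiling division over `PAGE_SIZE`. A standalone sketch of that bookkeeping; the names below are illustrative, not part of the test:

```python
def build_page_table(lengths, page_size, max_pages_per_seq):
    """Assign contiguous page ids per sequence; unused slots stay 0."""
    table, next_page = [], 0
    for length in lengths:
        needed = (length + page_size - 1) // page_size  # ceil division
        row = list(range(next_page, next_page + needed)) + [0] * (max_pages_per_seq - needed)
        next_page += needed
        table.append(row)
    return table


print(build_page_table(lengths=[5, 9, 1], page_size=4, max_pages_per_seq=4))
# [[0, 1, 0, 0], [2, 3, 4, 0], [5, 0, 0, 0]]
```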
# get fake input
|
||||
q = torch.randn(1, seq_len.sum(), n_heads * D_HEAD, **dtype_kwargs)
|
||||
k = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs)
|
||||
v = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs)
|
||||
|
||||
# run op
|
||||
output = torch.ops.auto_deploy.triton_attention_fused_mha_with_paged_cache(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
input_positions,
|
||||
cache_loc,
|
||||
seq_len,
|
||||
seq_start,
|
||||
page_table,
|
||||
cache_max_seq_len,
|
||||
k_cache,
|
||||
v_cache,
|
||||
None,
|
||||
)
|
||||
|
||||
# TODO (nvchenghaoz): Replace this with torch backend reference.
|
||||
|
||||
# prep batched tensors for comparison
|
||||
def compute_reference(q, k_cache, v_cache):
|
||||
ref = []
|
||||
for batch in range(batch_size):
|
||||
table = page_table[cache_loc[batch]]
|
||||
s_len = seq_len[batch]
|
||||
c_len = input_positions[batch]
|
||||
length = c_len + s_len
|
||||
cache_pages = length // PAGE_SIZE
|
||||
cache_page_off = length % PAGE_SIZE
|
||||
s_start = seq_start[batch]
|
||||
# [bsnd]
|
||||
qq = q[0, s_start : s_start + s_len].view(1, -1, n_heads, D_HEAD)
|
||||
kk = []
|
||||
vv = []
|
||||
kk.append(
|
||||
k_cache[table[0] : table[0] + cache_pages].reshape(
|
||||
1, cache_pages * PAGE_SIZE, n_kv_heads, D_HEAD
|
||||
)
|
||||
)
|
||||
kk.append(
|
||||
k_cache[table[0] + cache_pages, 0:cache_page_off].reshape(
|
||||
1, cache_page_off, n_kv_heads, D_HEAD
|
||||
)
|
||||
)
|
||||
# [bsnd]
|
||||
k_f = torch.cat(kk, 1)
|
||||
vv.append(
|
||||
v_cache[table[0] : table[0] + cache_pages].reshape(
|
||||
1, cache_pages * PAGE_SIZE, n_kv_heads, D_HEAD
|
||||
)
|
||||
)
|
||||
vv.append(
|
||||
v_cache[table[0] + cache_pages, 0:cache_page_off].reshape(
|
||||
1, cache_page_off, n_kv_heads, D_HEAD
|
||||
)
|
||||
)
|
||||
v_f = torch.cat(vv, 1)
|
||||
if n_heads != n_kv_heads:
|
||||
k_f = torch.repeat_interleave(k_f, group_size, dim=2)
|
||||
v_f = torch.repeat_interleave(v_f, group_size, dim=2)
|
||||
mask = torch.tril(
|
||||
torch.ones(s_len, length, dtype=torch.bool),
|
||||
diagonal=c_len,
|
||||
)
|
||||
ref.append(
|
||||
torch.nn.functional.scaled_dot_product_attention(
|
||||
qq.transpose(1, 2),
|
||||
k_f.transpose(1, 2),
|
||||
v_f.transpose(1, 2),
|
||||
attn_mask=mask.to(device),
|
||||
)
|
||||
.transpose(2, 1)
|
||||
.contiguous()
|
||||
.view(1, s_len, n_heads * D_HEAD) # [b,s,n*d]
|
||||
)
|
||||
return torch.cat(ref, 1)
|
||||
|
||||
ref = compute_reference(q, k_cache, v_cache)
|
||||
assert torch.allclose(ref, output, atol=1e-2, rtol=1e-2)
|
||||
|
||||
@ -6,17 +6,31 @@ from torch_attention_reference import TorchAttentionReference
from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_attention import _GlobalFlashInferPlanner


def _create_combined_kv_cache(k_cache: torch.Tensor, v_cache: torch.Tensor) -> torch.Tensor:
"""Create combined KV cache in HND layout from separate K and V caches.

Input shapes (NHD layout): [num_blocks, tokens_per_block, num_heads, head_dim]
Output shape (HND layout): [num_blocks, 2, num_heads, tokens_per_block, head_dim]
"""
# Input: [num_blocks, tokens_per_block, num_heads, head_dim]
# Permute to: [num_blocks, num_heads, tokens_per_block, head_dim]
k_hnd = k_cache.permute(0, 2, 1, 3)
v_hnd = v_cache.permute(0, 2, 1, 3)
# Stack along kv dimension: [num_blocks, 2, num_heads, tokens_per_block, head_dim]
return torch.stack([k_hnd, v_hnd], dim=1)


def _attention_with_fp8_kv_cache(
q, k, v, k_cache, v_cache, k_scale, v_scale, prefill_seq_len, causal, mask
q, k, v, kv_cache, k_scale, v_scale, prefill_seq_len, causal, mask
):
"""Simulates attention for fp8 kv cache with q,k,v outputs of GEMM in fp16"""
batch_size, seq_len, _ = k.shape
# Quantize k and v
# k = (k / k_scale).to(torch.float8_e4m3fn)
# v = (v / v_scale).to(torch.float8_e4m3fn)
# Append to kv cache
# k_cache[0:batch_size, prefill_seq_len : prefill_seq_len + seq_len, :, :] = k
# v_cache[0:batch_size, prefill_seq_len : prefill_seq_len + seq_len, :, :] = v
# kv_cache shape: [num_blocks, 2, num_heads, tokens_per_block, head_dim] (HND layout)
# Extract k and v, convert back to NHD layout for reference
k_cache_hnd = kv_cache[:, 0, :, :, :] # [num_blocks, num_heads, tokens_per_block, head_dim]
v_cache_hnd = kv_cache[:, 1, :, :, :]
k_cache = k_cache_hnd.permute(0, 2, 1, 3) # [num_blocks, tokens_per_block, num_heads, head_dim]
v_cache = v_cache_hnd.permute(0, 2, 1, 3)

# Compute attention
# Step 1: Retrieve KV cache
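The helper and the FP8 reference above both hinge on the distinction between the HND layout `[num_blocks, 2, num_heads, tokens_per_block, head_dim]` and the NHD layout `[num_blocks, tokens_per_block, num_heads, head_dim]`. A small self-contained round-trip check of that conversion; the shapes are arbitrary test values:

```python
import torch

num_blocks, tokens_per_block, num_heads, head_dim = 3, 8, 4, 16
k = torch.randn(num_blocks, tokens_per_block, num_heads, head_dim)
v = torch.randn(num_blocks, tokens_per_block, num_heads, head_dim)

# NHD -> HND: permute to [blocks, heads, tokens, dim], then stack K/V along dim=1
kv_hnd = torch.stack([k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)], dim=1)
assert kv_hnd.shape == (num_blocks, 2, num_heads, tokens_per_block, head_dim)

# HND -> NHD: index the K/V slot and permute back
k_back = kv_hnd[:, 0].permute(0, 2, 1, 3)
v_back = kv_hnd[:, 1].permute(0, 2, 1, 3)
assert torch.equal(k, k_back) and torch.equal(v, v_back)
```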
@ -75,12 +89,10 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
|
||||
k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
v = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
|
||||
# Setup KV Cache. KV cache is empty, context phase
|
||||
k_cache = torch.zeros(
|
||||
(MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
(MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device
|
||||
# Setup KV Cache in HND layout: [num_blocks, 2, num_heads, tokens_per_block, head_dim]
|
||||
# For unpaged case: num_blocks=MAX_BATCH_SIZE, tokens_per_block=MAX_SEQ_LEN
|
||||
kv_cache = torch.zeros(
|
||||
(MAX_BATCH_SIZE, 2, N_HEADS, MAX_SEQ_LEN, D_HEAD), dtype=DTYPE, device=device
|
||||
)
|
||||
|
||||
# make sure planner is initialized
|
||||
@ -89,7 +101,7 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr,
|
||||
flashinfer.get_seq_lens(
|
||||
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
|
||||
paged_kv_indptr, paged_kv_last_page_len, page_size=kv_cache.shape[3]
|
||||
),
|
||||
BATCH_SIZE * SEQ_LEN,
|
||||
)
|
||||
@ -114,15 +126,19 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
|
||||
# EXTRA METADATA
|
||||
batch_indices,
|
||||
positions,
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# CACHES - combined KV cache in HND layout
|
||||
kv_cache,
|
||||
# CONSTANTS
|
||||
None,
|
||||
1.0,
|
||||
1.0,
|
||||
)
|
||||
|
||||
# Extract k_cache and v_cache for reference computation (convert HND to NHD)
|
||||
# Need .contiguous() after permute() for later operations
|
||||
k_cache = kv_cache[:, 0, :, :, :].permute(0, 2, 1, 3).contiguous() # [batch, seq, heads, dim]
|
||||
v_cache = kv_cache[:, 1, :, :, :].permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
# Use torch backend as clean reference
|
||||
q_reshaped = q.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
|
||||
k_reshaped = k.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
|
||||
@ -185,26 +201,23 @@ def test_flashinfer_attention_op_decode(
|
||||
k = torch.ones(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
v = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
|
||||
# Setup KV Cache. KV cache is partially filled from the context phase
|
||||
k_cache = torch.zeros(
|
||||
# Setup KV Cache in HND layout: [num_blocks, 2, num_heads, tokens_per_block, head_dim]
|
||||
kv_cache = torch.zeros(
|
||||
(MAX_BATCH_SIZE, 2, N_HEADS, MAX_SEQ_LEN, D_HEAD), dtype=DTYPE, device=device
|
||||
)
|
||||
|
||||
# Initialize prefilled portion (zeros in this test)
|
||||
# kv_cache is already zeros
|
||||
|
||||
# Generate reference cache in NHD layout for comparison
|
||||
k_cache_ref = torch.zeros(
|
||||
(MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
v_cache_ref = torch.zeros(
|
||||
(MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device
|
||||
)
|
||||
|
||||
k_cache[0:BATCH_SIZE, 0:PREFILL_SEQ_LEN, :, :] = torch.zeros(
|
||||
BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS, D_HEAD
|
||||
)
|
||||
v_cache[0:BATCH_SIZE, 0:PREFILL_SEQ_LEN, :, :] = torch.zeros(
|
||||
BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS, D_HEAD
|
||||
)
|
||||
|
||||
# Generate reference cache
|
||||
k_cache_ref = k_cache.clone()
|
||||
v_cache_ref = v_cache.clone()
|
||||
|
||||
# Apply RoPE to k_cache
|
||||
# Fill expected values after append
|
||||
k_cache_ref[0:BATCH_SIZE, PREFILL_SEQ_LEN : PREFILL_SEQ_LEN + SEQ_LEN, :, :] = k.view(
|
||||
BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD
|
||||
)
|
||||
@ -212,26 +225,13 @@ def test_flashinfer_attention_op_decode(
|
||||
BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD
|
||||
)
|
||||
|
||||
assert not torch.allclose(
|
||||
k_cache_ref.to(torch.float32),
|
||||
k_cache.to(torch.float32),
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
)
|
||||
assert not torch.allclose(
|
||||
v_cache_ref.to(torch.float32),
|
||||
v_cache.to(torch.float32),
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
)
|
||||
|
||||
# make sure planner is initialized
|
||||
_GlobalFlashInferPlanner.reset(torch.device(device))
|
||||
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr,
|
||||
flashinfer.get_seq_lens(
|
||||
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
|
||||
paged_kv_indptr, paged_kv_last_page_len, page_size=kv_cache.shape[3]
|
||||
),
|
||||
BATCH_SIZE * SEQ_LEN,
|
||||
)
|
||||
@ -255,15 +255,19 @@ def test_flashinfer_attention_op_decode(
|
||||
# EXTRA METADATA
|
||||
batch_indices,
|
||||
positions,
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# CACHES - combined KV cache in HND layout
|
||||
kv_cache,
|
||||
# CONSTANTS
|
||||
None,
|
||||
1.0,
|
||||
1.0,
|
||||
)
|
||||
|
||||
# Extract k_cache and v_cache for reference (convert HND to NHD)
|
||||
# Need .contiguous() after permute() for later operations
|
||||
k_cache = kv_cache[:, 0, :, :, :].permute(0, 2, 1, 3).contiguous() # [batch, seq, heads, dim]
|
||||
v_cache = kv_cache[:, 1, :, :, :].permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
assert torch.allclose(
|
||||
k_cache_ref.to(torch.float32),
|
||||
k_cache.to(torch.float32),
|
||||
@ -280,8 +284,8 @@ def test_flashinfer_attention_op_decode(
|
||||
# Generate reference outputs
|
||||
q_ref = q.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
|
||||
|
||||
k_ref = k_cache[:BATCH_SIZE, : PREFILL_SEQ_LEN + SEQ_LEN, :, :]
|
||||
v_ref = v_cache[:BATCH_SIZE, : PREFILL_SEQ_LEN + SEQ_LEN, :, :]
|
||||
k_ref = k_cache[:BATCH_SIZE, : PREFILL_SEQ_LEN + SEQ_LEN, :, :].clone()
|
||||
v_ref = v_cache[:BATCH_SIZE, : PREFILL_SEQ_LEN + SEQ_LEN, :, :].clone()
|
||||
k_ref[:, PREFILL_SEQ_LEN : PREFILL_SEQ_LEN + SEQ_LEN, :, :] = k.view(
|
||||
BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD
|
||||
)
|
||||
@ -346,12 +350,9 @@ def test_flashinfer_attention_context_and_generate(
|
||||
k_1 = torch.randn(BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
v_1 = torch.randn(BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
|
||||
# Setup KV Cache. KV cache is empty, context phase
|
||||
k_cache = torch.zeros(
|
||||
(MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
(MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device
|
||||
# Setup KV Cache in HND layout: [num_blocks, 2, num_heads, tokens_per_block, head_dim]
|
||||
kv_cache = torch.zeros(
|
||||
(MAX_BATCH_SIZE, 2, N_HEADS, MAX_SEQ_LEN, D_HEAD), dtype=DTYPE, device=device
|
||||
)
|
||||
|
||||
# make sure planner is initialized
|
||||
@ -360,7 +361,7 @@ def test_flashinfer_attention_context_and_generate(
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr,
|
||||
flashinfer.get_seq_lens(
|
||||
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
|
||||
paged_kv_indptr, paged_kv_last_page_len, page_size=kv_cache.shape[3]
|
||||
),
|
||||
BATCH_SIZE * PREFILL_SEQ_LEN,
|
||||
)
|
||||
@ -385,19 +386,23 @@ def test_flashinfer_attention_context_and_generate(
|
||||
# EXTRA METADATA
|
||||
batch_indices,
|
||||
positions,
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# CACHES - combined KV cache in HND layout
|
||||
kv_cache,
|
||||
# CONSTANTS
|
||||
None,
|
||||
1.0,
|
||||
1.0,
|
||||
)
|
||||
|
||||
# Extract k_cache and v_cache for reference (convert HND to NHD)
|
||||
# Need .contiguous() after permute() for later operations
|
||||
k_cache = kv_cache[:, 0, :, :, :].permute(0, 2, 1, 3).contiguous() # [batch, seq, heads, dim]
|
||||
v_cache = kv_cache[:, 1, :, :, :].permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
# Generate reference outputs
|
||||
q_ref = q_1
|
||||
k_ref = k_cache[:BATCH_SIZE, 0:PREFILL_SEQ_LEN, :, :]
|
||||
v_ref = v_cache[:BATCH_SIZE, 0:PREFILL_SEQ_LEN, :, :]
|
||||
k_ref = k_cache[:BATCH_SIZE, 0:PREFILL_SEQ_LEN, :, :].clone()
|
||||
v_ref = v_cache[:BATCH_SIZE, 0:PREFILL_SEQ_LEN, :, :].clone()
|
||||
|
||||
# Use torch backend as clean reference
|
||||
ref = TorchAttentionReference.basic_mha_with_cache(
|
||||
@ -419,6 +424,7 @@ def test_flashinfer_attention_context_and_generate(
|
||||
rtol=1e-2,
|
||||
)
|
||||
|
||||
# Update k_cache view (which reflects changes to kv_cache)
|
||||
k_cache[:BATCH_SIZE, 0:PREFILL_SEQ_LEN, :, :] = k_ref
|
||||
|
||||
# Generate output
|
||||
@ -450,7 +456,7 @@ def test_flashinfer_attention_context_and_generate(
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr,
|
||||
flashinfer.get_seq_lens(
|
||||
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
|
||||
paged_kv_indptr, paged_kv_last_page_len, page_size=kv_cache.shape[3]
|
||||
),
|
||||
BATCH_SIZE * 1,
|
||||
)
|
||||
@ -473,19 +479,23 @@ def test_flashinfer_attention_context_and_generate(
|
||||
# EXTRA METADATA
|
||||
batch_indices,
|
||||
positions,
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# CACHES - combined KV cache in HND layout
|
||||
kv_cache,
|
||||
# CONSTANTS
|
||||
None,
|
||||
1.0,
|
||||
1.0,
|
||||
)
|
||||
|
||||
# Re-extract k_cache and v_cache (may have changed)
|
||||
# Need .contiguous() after permute() for later operations
|
||||
k_cache = kv_cache[:, 0, :, :, :].permute(0, 2, 1, 3).contiguous()
|
||||
v_cache = kv_cache[:, 1, :, :, :].permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
# Generate reference outputs
|
||||
q_ref = torch.cat([q_1, q_3], dim=-2)
|
||||
k_ref = k_cache[:BATCH_SIZE, PREFILL_SEQ_LEN : PREFILL_SEQ_LEN + 1, :, :]
|
||||
v_ref = v_cache[:BATCH_SIZE, PREFILL_SEQ_LEN : PREFILL_SEQ_LEN + 1, :, :]
|
||||
k_ref = k_cache[:BATCH_SIZE, PREFILL_SEQ_LEN : PREFILL_SEQ_LEN + 1, :, :].clone()
|
||||
v_ref = v_cache[:BATCH_SIZE, PREFILL_SEQ_LEN : PREFILL_SEQ_LEN + 1, :, :].clone()
|
||||
|
||||
ref = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_3.view(BATCH_SIZE, 1, N_HEADS, D_HEAD).transpose(1, 2),
|
||||
@ -555,23 +565,21 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
|
||||
k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
v = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
|
||||
# Setup KV Cache. KV cache is partially filled from the context phase
|
||||
k_cache = torch.zeros(
|
||||
(MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
(MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device
|
||||
# Setup KV Cache in HND layout: [num_blocks, 2, num_heads, tokens_per_block, head_dim]
|
||||
kv_cache = torch.zeros(
|
||||
(MAX_BATCH_SIZE, 2, N_HEADS, MAX_SEQ_LEN, D_HEAD), dtype=DTYPE, device=device
|
||||
)
|
||||
|
||||
# Initialize the prefilled portion of the cache with random data
|
||||
# This simulates a chunked prefill scenario where previous chunks have already
|
||||
# populated the cache at positions 0:PREFILL_SEQ_LEN
|
||||
if PREFILL_SEQ_LEN > 0:
|
||||
k_cache[0:BATCH_SIZE, 0:PREFILL_SEQ_LEN, :, :] = torch.randn(
|
||||
BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS, D_HEAD, dtype=DTYPE, device=device
|
||||
# HND layout: [batch, kv, heads, seq, dim] - fill k and v separately
|
||||
kv_cache[0:BATCH_SIZE, 0, :, 0:PREFILL_SEQ_LEN, :] = torch.randn(
|
||||
BATCH_SIZE, N_HEADS, PREFILL_SEQ_LEN, D_HEAD, dtype=DTYPE, device=device
|
||||
)
|
||||
v_cache[0:BATCH_SIZE, 0:PREFILL_SEQ_LEN, :, :] = torch.randn(
|
||||
BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS, D_HEAD, dtype=DTYPE, device=device
|
||||
kv_cache[0:BATCH_SIZE, 1, :, 0:PREFILL_SEQ_LEN, :] = torch.randn(
|
||||
BATCH_SIZE, N_HEADS, PREFILL_SEQ_LEN, D_HEAD, dtype=DTYPE, device=device
|
||||
)
|
||||
|
||||
# make sure planner is initialized
|
||||
@ -580,7 +588,7 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr,
|
||||
flashinfer.get_seq_lens(
|
||||
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
|
||||
paged_kv_indptr, paged_kv_last_page_len, page_size=kv_cache.shape[3]
|
||||
),
|
||||
BATCH_SIZE * SEQ_LEN,
|
||||
)
|
||||
@ -605,22 +613,26 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
|
||||
# EXTRA METADATA
|
||||
batch_indices,
|
||||
positions,
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# CACHES - combined KV cache in HND layout
|
||||
kv_cache,
|
||||
# CONSTANTS
|
||||
None,
|
||||
1.0,
|
||||
1.0,
|
||||
)
|
||||
|
||||
# Extract k_cache and v_cache for reference (convert HND to NHD)
|
||||
# Need .contiguous() after permute() for later operations
|
||||
k_cache = kv_cache[:, 0, :, :, :].permute(0, 2, 1, 3).contiguous() # [batch, seq, heads, dim]
|
||||
v_cache = kv_cache[:, 1, :, :, :].permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
# Generate ref
|
||||
q_ref = q.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
|
||||
k_ref = k_cache[0:BATCH_SIZE, 0 : PREFILL_SEQ_LEN + SEQ_LEN, :, :]
|
||||
v_ref = v_cache[0:BATCH_SIZE, 0 : PREFILL_SEQ_LEN + SEQ_LEN, :, :]
|
||||
k_ref = k_cache[0:BATCH_SIZE, 0 : PREFILL_SEQ_LEN + SEQ_LEN, :, :].contiguous()
|
||||
v_ref = v_cache[0:BATCH_SIZE, 0 : PREFILL_SEQ_LEN + SEQ_LEN, :, :].contiguous()
|
||||
|
||||
q_ref = q_ref.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
|
||||
k_ref = k_ref.view(BATCH_SIZE, PREFILL_SEQ_LEN + SEQ_LEN, N_HEADS, D_HEAD)
|
||||
k_ref = k_ref.reshape(BATCH_SIZE, PREFILL_SEQ_LEN + SEQ_LEN, N_HEADS, D_HEAD)
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.ones(SEQ_LEN, PREFILL_SEQ_LEN, device=device, dtype=torch.bool),
|
||||
@ -695,24 +707,21 @@ def test_flashinfer_attention_with_fp8_cache(
|
||||
k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
v = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
|
||||
# Setup KV Cache. KV cache is empty, context phase
|
||||
k_cache = torch.zeros(
|
||||
(MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
(MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device
|
||||
# Setup KV Cache in HND layout: [num_blocks, 2, num_heads, tokens_per_block, head_dim]
|
||||
kv_cache = torch.zeros(
|
||||
(MAX_BATCH_SIZE, 2, N_HEADS, MAX_SEQ_LEN, D_HEAD), dtype=DTYPE, device=device
|
||||
)
|
||||
|
||||
if PREFILL_SEQ_LEN > 0:
|
||||
k_cache[0:BATCH_SIZE, 0:PREFILL_SEQ_LEN, :, :] = torch.randn(
|
||||
BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS, D_HEAD
|
||||
# HND layout: [batch, kv, heads, seq, dim]
|
||||
kv_cache[0:BATCH_SIZE, 0, :, 0:PREFILL_SEQ_LEN, :] = torch.randn(
|
||||
BATCH_SIZE, N_HEADS, PREFILL_SEQ_LEN, D_HEAD
|
||||
)
|
||||
v_cache[0:BATCH_SIZE, 0:PREFILL_SEQ_LEN, :, :] = torch.randn(
|
||||
BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS, D_HEAD
|
||||
kv_cache[0:BATCH_SIZE, 1, :, 0:PREFILL_SEQ_LEN, :] = torch.randn(
|
||||
BATCH_SIZE, N_HEADS, PREFILL_SEQ_LEN, D_HEAD
|
||||
)
|
||||
|
||||
k_cache = k_cache / K_SCALE
|
||||
v_cache = v_cache / V_SCALE
|
||||
kv_cache = kv_cache / K_SCALE
|
||||
|
||||
# Set causal mask to false if its a partially filled kv_cache
|
||||
causal = False
|
||||
@ -730,8 +739,7 @@ def test_flashinfer_attention_with_fp8_cache(
|
||||
causal = True
|
||||
mask = None
|
||||
|
||||
k_cache = k_cache.to(torch.float8_e4m3fn)
|
||||
v_cache = v_cache.to(torch.float8_e4m3fn)
|
||||
kv_cache = kv_cache.to(torch.float8_e4m3fn)
|
||||
|
||||
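Dividing by the scale before casting to `torch.float8_e4m3fn` (and multiplying the scale back in at read time) is the usual per-tensor FP8 scheme this test exercises. A minimal round-trip sketch, assuming a PyTorch build with `float8_e4m3fn` support; the scale value is arbitrary:

```python
import torch

k_scale = 0.05
k = torch.randn(2, 16, 8, 64, dtype=torch.float16)

# quantize: scale down, cast to fp8 storage
k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)

# dequantize: cast up, scale back
k_deq = k_fp8.to(torch.float16) * k_scale

# fp8 e4m3 keeps roughly 2-3 significant digits, so expect only coarse agreement
print((k - k_deq).abs().max())
```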
# make sure planner is initialized
|
||||
_GlobalFlashInferPlanner.reset(torch.device(device))
|
||||
@ -739,7 +747,7 @@ def test_flashinfer_attention_with_fp8_cache(
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr,
|
||||
flashinfer.get_seq_lens(
|
||||
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
|
||||
paged_kv_indptr, paged_kv_last_page_len, page_size=kv_cache.shape[3]
|
||||
),
|
||||
BATCH_SIZE * SEQ_LEN,
|
||||
)
|
||||
@ -764,9 +772,8 @@ def test_flashinfer_attention_with_fp8_cache(
|
||||
# EXTRA METADATA
|
||||
batch_indices,
|
||||
positions,
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# CACHES - combined KV cache in HND layout
|
||||
kv_cache,
|
||||
# CONSTANTS
|
||||
None,
|
||||
K_SCALE,
|
||||
@ -777,7 +784,7 @@ def test_flashinfer_attention_with_fp8_cache(
|
||||
q = q.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
|
||||
|
||||
ref = _attention_with_fp8_kv_cache(
|
||||
q, k, v, k_cache, v_cache, K_SCALE, V_SCALE, PREFILL_SEQ_LEN, causal, mask
|
||||
q, k, v, kv_cache, K_SCALE, V_SCALE, PREFILL_SEQ_LEN, causal, mask
|
||||
)
|
||||
|
||||
assert torch.allclose(
|
||||
@ -809,9 +816,10 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
|
||||
k = torch.randn(1, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
v = torch.randn(1, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
|
||||
# Setup KV Cache. KV cache is empty, context phase
|
||||
k_cache = torch.zeros((MAX_NUM_PAGES, PAGE_SIZE, N_HEADS, D_HEAD), dtype=DTYPE, device=device)
|
||||
v_cache = torch.zeros((MAX_NUM_PAGES, PAGE_SIZE, N_HEADS, D_HEAD), dtype=DTYPE, device=device)
|
||||
# Setup KV Cache in HND layout: [num_pages, 2, num_heads, page_size, head_dim]
|
||||
kv_cache = torch.zeros(
|
||||
(MAX_NUM_PAGES, 2, N_HEADS, PAGE_SIZE, D_HEAD), dtype=DTYPE, device=device
|
||||
)
|
||||
offsets = torch.zeros(BATCH_SIZE, device=device, dtype=torch.int)
|
||||
|
||||
# assign pages
|
||||
@ -847,7 +855,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr,
|
||||
flashinfer.get_seq_lens(
|
||||
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
|
||||
paged_kv_indptr, paged_kv_last_page_len, page_size=kv_cache.shape[3]
|
||||
),
|
||||
SEQ_LEN,
|
||||
)
|
||||
@ -870,9 +878,8 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
|
||||
# EXTRA METADATA
|
||||
batch_indices,
|
||||
positions,
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# CACHES - combined KV cache in HND layout
|
||||
kv_cache,
|
||||
# CONSTANTS
|
||||
None,
|
||||
1.0,
|
||||
@ -941,7 +948,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
|
||||
batch_indices, positions = flashinfer.get_batch_indices_positions(
|
||||
qo_indptr2,
|
||||
flashinfer.get_seq_lens(
|
||||
paged_kv_indptr2, paged_kv_last_page_len2, page_size=k_cache.shape[1]
|
||||
paged_kv_indptr2, paged_kv_last_page_len2, page_size=kv_cache.shape[3]
|
||||
),
|
||||
BATCH_SIZE * 1,
|
||||
)
|
||||
@ -964,9 +971,8 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
|
||||
# EXTRA METADATA
|
||||
batch_indices,
|
||||
positions,
|
||||
# CACHES
|
||||
k_cache,
|
||||
v_cache,
|
||||
# CACHES - combined KV cache in HND layout
|
||||
kv_cache,
|
||||
# CONSTANTS
|
||||
None,
|
||||
1.0,
|
||||
|
||||
@ -1,7 +1,7 @@
"""Unit tests for ResourceHandler classes in attention_interface.py.

Tests the new resource handler abstraction for cache management:
- PagedResourceHandler (for paged KV caches)
- KVPagedResourceHandler (for paged KV caches)
- StateResourceHandler (for SSM/conv states)
- UnpagedResourceHandler (for unpaged local caches)
- AttentionDescriptor.resolve_cache_dtype()
@ -12,8 +12,7 @@ import torch
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import (
|
||||
AttentionDescriptor,
|
||||
ManagedResourceHandler,
|
||||
PagedResourceHandler,
|
||||
KVPagedResourceHandler,
|
||||
ResourceHandler,
|
||||
SequenceInfo,
|
||||
StateResourceHandler,
|
||||
@ -21,52 +20,68 @@ from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import (
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# PagedResourceHandler Tests
|
||||
# KVPagedResourceHandler Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_paged_handler_stores_token_shape_and_dtype():
|
||||
"""Verify PagedResourceHandler stores token_shape and dtype correctly."""
|
||||
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
assert handler.token_shape == (8, 64)
|
||||
assert handler.dtype == torch.float16
|
||||
|
||||
|
||||
def test_paged_handler_single_dimension_token_shape():
|
||||
"""Test PagedResourceHandler with single dimension token shape."""
|
||||
handler = PagedResourceHandler(128, dtype=torch.bfloat16)
|
||||
assert handler.token_shape == (128,)
|
||||
def test_paged_handler_with_nhd_layout():
|
||||
"""Test KVPagedResourceHandler with NHD layout."""
|
||||
handler = KVPagedResourceHandler(8, 64, dtype=torch.bfloat16, kv_layout="NHD")
|
||||
assert handler.num_kv_heads == 8
|
||||
assert handler.head_dim == 64
|
||||
assert handler.dtype == torch.bfloat16
|
||||
assert handler.kv_layout == "NHD"
|
||||
|
||||
|
||||
def test_paged_handler_multi_dimension_token_shape():
|
||||
"""Test PagedResourceHandler with multiple dimension token shape."""
|
||||
handler = PagedResourceHandler(4, 8, 16, dtype=torch.float32)
|
||||
assert handler.token_shape == (4, 8, 16)
|
||||
def test_paged_handler_with_hnd_layout():
|
||||
"""Test KVPagedResourceHandler with explicit HND layout."""
|
||||
handler = KVPagedResourceHandler(4, 128, dtype=torch.float32, kv_layout="HND")
|
||||
assert handler.num_kv_heads == 4
|
||||
assert handler.head_dim == 128
|
||||
assert handler.dtype == torch.float32
|
||||
assert handler.kv_layout == "HND"
|
||||
|
||||
|
||||
def test_paged_handler_allocate_raises_not_implemented():
|
||||
"""Verify PagedResourceHandler.allocate() raises NotImplementedError."""
|
||||
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4)
|
||||
@pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
|
||||
def test_paged_handler_allocate_with_blocks(kv_layout):
|
||||
"""Verify KVPagedResourceHandler.allocate() returns correct shape."""
|
||||
handler = KVPagedResourceHandler(8, 64, dtype=torch.float16, kv_layout=kv_layout)
|
||||
tokens_per_block = 32
|
||||
seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, tokens_per_block=tokens_per_block)
|
||||
seq_info.to("cuda")
|
||||
# Set up num_blocks via estimate_cache_loc_capacity
|
||||
seq_info.estimate_cache_loc_capacity(num_blocks=10)
|
||||
|
||||
with pytest.raises(NotImplementedError, match="Managed resources should not be allocated"):
|
||||
handler.allocate(seq_info)
|
||||
tensor = handler.allocate(seq_info)
|
||||
|
||||
if kv_layout == "HND":
|
||||
expected_shape = (
|
||||
10,
|
||||
2,
|
||||
8,
|
||||
tokens_per_block,
|
||||
64,
|
||||
) # [num_blocks, 2, num_kv_heads, tokens_per_block, head_dim]
|
||||
else: # NHD
|
||||
expected_shape = (
|
||||
10,
|
||||
tokens_per_block,
|
||||
2,
|
||||
8,
|
||||
64,
|
||||
) # [num_blocks, tokens_per_block, 2, num_kv_heads, head_dim]
|
||||
|
||||
assert tensor.shape == expected_shape
|
||||
assert tensor.dtype == torch.float16
|
||||
assert tensor.device.type == "cuda"
|
||||
|
||||
|
||||
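The expected shapes in the test follow directly from the layout string. A small helper that reproduces that shape arithmetic; it is illustrative only, not the handler's actual API:

```python
def paged_kv_cache_shape(num_blocks, num_kv_heads, head_dim, tokens_per_block, kv_layout="HND"):
    """Shape of a combined K/V block pool for the given layout."""
    if kv_layout == "HND":
        return (num_blocks, 2, num_kv_heads, tokens_per_block, head_dim)
    if kv_layout == "NHD":
        return (num_blocks, tokens_per_block, 2, num_kv_heads, head_dim)
    raise ValueError(f"unknown kv_layout: {kv_layout}")


assert paged_kv_cache_shape(10, 8, 64, 32, "HND") == (10, 2, 8, 32, 64)
assert paged_kv_cache_shape(10, 8, 64, 32, "NHD") == (10, 32, 2, 8, 64)
```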
def test_paged_handler_is_resource_handler():
|
||||
"""Verify PagedResourceHandler is a ResourceHandler subclass."""
|
||||
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
"""Verify KVPagedResourceHandler is a ResourceHandler subclass."""
|
||||
handler = KVPagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
assert isinstance(handler, ResourceHandler)
|
||||
|
||||
|
||||
def test_paged_handler_is_managed_resource():
|
||||
"""Verify PagedResourceHandler is a ManagedResourceHandler."""
|
||||
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
assert isinstance(handler, ManagedResourceHandler)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# StateResourceHandler Tests
|
||||
# =============================================================================
|
||||
@ -100,13 +115,19 @@ def test_state_handler_ssm_state_shape():
|
||||
assert handler.dtype == torch.float32
|
||||
|
||||
|
||||
def test_state_handler_allocate_raises_not_implemented():
|
||||
"""Verify StateResourceHandler.allocate() raises NotImplementedError."""
|
||||
def test_state_handler_allocate_creates_tensor():
|
||||
"""Verify StateResourceHandler.allocate() creates tensor with correct shape."""
|
||||
handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
|
||||
seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4)
|
||||
seq_info.to("cuda")
|
||||
|
||||
with pytest.raises(NotImplementedError, match="Managed resources should not be allocated"):
|
||||
handler.allocate(seq_info)
|
||||
tensor = handler.allocate(seq_info)
|
||||
|
||||
# Shape: [max_num_state_slots, *state_shape]
|
||||
expected_shape = (seq_info.max_num_state_slots, 4, 64, 16)
|
||||
assert tensor.shape == expected_shape
|
||||
assert tensor.dtype == torch.bfloat16
|
||||
assert tensor.device.type == "cuda"
|
||||
|
||||
|
||||
def test_state_handler_is_resource_handler():
|
||||
@ -115,12 +136,6 @@ def test_state_handler_is_resource_handler():
|
||||
assert isinstance(handler, ResourceHandler)
|
||||
|
||||
|
||||
def test_state_handler_is_managed_resource():
|
||||
"""Verify StateResourceHandler is a ManagedResourceHandler."""
|
||||
handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
|
||||
assert isinstance(handler, ManagedResourceHandler)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# UnpagedResourceHandler Tests
|
||||
# =============================================================================
|
||||
@ -181,12 +196,6 @@ def test_unpaged_handler_is_resource_handler():
|
||||
assert isinstance(handler, ResourceHandler)
|
||||
|
||||
|
||||
def test_unpaged_handler_is_not_managed_resource():
|
||||
"""Verify UnpagedResourceHandler is NOT a ManagedResourceHandler."""
|
||||
handler = UnpagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
assert not isinstance(handler, ManagedResourceHandler)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# AttentionDescriptor.resolve_cache_dtype() Tests
|
||||
# =============================================================================
|
||||
@ -235,3 +244,89 @@ def test_resolve_cache_dtype_explicit_fp8():
|
||||
"""Test explicit 'fp8' dtype string resolves correctly."""
|
||||
result = AttentionDescriptor.resolve_cache_dtype("fp8", torch.float16)
|
||||
assert result == torch.float8_e4m3fn
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Resource Handler __eq__ Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_kv_paged_handler_eq_same_head_dim_dtype():
|
||||
"""Verify KVPagedResourceHandler __eq__ checks head_dim and dtype."""
|
||||
h1 = KVPagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
h2 = KVPagedResourceHandler(4, 64, dtype=torch.float16) # Different num_kv_heads
|
||||
h3 = KVPagedResourceHandler(8, 64, dtype=torch.float16, kv_layout="NHD") # Different layout
|
||||
|
||||
# head_dim, kv_factor, dtype, kv_layout -> equal (num_kv_heads doesn't matter for compatibility)
|
||||
assert h1 == h2
|
||||
assert h1 != h3
|
||||
|
||||
|
||||
def test_kv_paged_handler_eq_different_head_dim_or_dtype():
|
||||
"""Verify KVPagedResourceHandler __eq__ returns False for different head_dim or dtype."""
|
||||
h1 = KVPagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
h2 = KVPagedResourceHandler(8, 128, dtype=torch.float16) # Different head_dim
|
||||
h3 = KVPagedResourceHandler(8, 64, dtype=torch.bfloat16) # Different dtype
|
||||
|
||||
assert h1 != h2
|
||||
assert h1 != h3
|
||||
|
||||
|
||||
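The equality tests above encode a compatibility rule: two paged KV handlers can share a pool when head_dim, dtype, and layout match, even if their head counts differ. A schematic dataclass with that semantics, not the real handler class:

```python
from dataclasses import dataclass

import torch


@dataclass
class KVPagedHandlerSketch:
    num_kv_heads: int
    head_dim: int
    dtype: torch.dtype = torch.float16
    kv_layout: str = "HND"

    def __eq__(self, other) -> bool:
        # num_kv_heads is intentionally ignored: pools are compatible across head counts
        return (
            isinstance(other, KVPagedHandlerSketch)
            and self.head_dim == other.head_dim
            and self.dtype == other.dtype
            and self.kv_layout == other.kv_layout
        )


assert KVPagedHandlerSketch(8, 64) == KVPagedHandlerSketch(4, 64)
assert KVPagedHandlerSketch(8, 64) != KVPagedHandlerSketch(8, 64, kv_layout="NHD")
```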
def test_ssm_handler_eq_same_params():
|
||||
"""Verify SSMResourceHandler __eq__ for same parameters."""
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SSMResourceHandler
|
||||
|
||||
h1 = SSMResourceHandler(num_heads=8, head_dim=64, d_state=16, dtype=torch.bfloat16)
|
||||
h2 = SSMResourceHandler(num_heads=8, head_dim=64, d_state=16, dtype=torch.bfloat16)
|
||||
|
||||
assert h1 == h2
|
||||
|
||||
|
||||
def test_ssm_handler_eq_different_params():
|
||||
"""Verify SSMResourceHandler __eq__ returns False for different parameters."""
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SSMResourceHandler
|
||||
|
||||
h1 = SSMResourceHandler(num_heads=8, head_dim=64, d_state=16, dtype=torch.bfloat16)
|
||||
h2 = SSMResourceHandler(
|
||||
num_heads=4, head_dim=64, d_state=16, dtype=torch.bfloat16
|
||||
) # diff heads
|
||||
h3 = SSMResourceHandler(
|
||||
num_heads=8, head_dim=128, d_state=16, dtype=torch.bfloat16
|
||||
) # diff head_dim
|
||||
h4 = SSMResourceHandler(
|
||||
num_heads=8, head_dim=64, d_state=32, dtype=torch.bfloat16
|
||||
) # diff d_state
|
||||
h5 = SSMResourceHandler(num_heads=8, head_dim=64, d_state=16, dtype=torch.float32) # diff dtype
|
||||
|
||||
assert h1 != h2
|
||||
assert h1 != h3
|
||||
assert h1 != h4
|
||||
assert h1 != h5


def test_conv_handler_eq_same_params():
    """Verify CausalConvResourceHandler __eq__ for same parameters."""
    from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import (
        CausalConvResourceHandler,
    )

    h1 = CausalConvResourceHandler(conv_dim=256, d_conv=4, dtype=torch.float32)
    h2 = CausalConvResourceHandler(conv_dim=256, d_conv=4, dtype=torch.float32)

    assert h1 == h2


def test_conv_handler_eq_different_params():
    """Verify CausalConvResourceHandler __eq__ returns False for different parameters."""
    from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import (
        CausalConvResourceHandler,
    )

    h1 = CausalConvResourceHandler(conv_dim=256, d_conv=4, dtype=torch.float32)
    h2 = CausalConvResourceHandler(conv_dim=512, d_conv=4, dtype=torch.float32)  # diff conv_dim
    h3 = CausalConvResourceHandler(conv_dim=256, d_conv=5, dtype=torch.float32)  # diff d_conv
    h4 = CausalConvResourceHandler(conv_dim=256, d_conv=4, dtype=torch.bfloat16)  # diff dtype

    assert h1 != h2
    assert h1 != h3
    assert h1 != h4
|
||||
|
||||
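The equality tests above treat `__eq__` as a compatibility check rather than strict identity: handlers that differ only in `num_kv_heads` still compare equal, while a different `head_dim`, `dtype`, or `kv_layout` makes them incompatible. The following is a minimal, self-contained sketch of how such a compatibility key could be used to group caches for pooled allocation; it uses a toy dataclass rather than the TensorRT-LLM handler classes, and `compat_key` / `group_compatible` are illustrative names, not real APIs.

```python
from collections import defaultdict
from dataclasses import dataclass

import torch


@dataclass
class ToyKVHandler:
    """Toy stand-in for a paged KV resource handler."""

    num_kv_heads: int
    head_dim: int
    dtype: torch.dtype = torch.float16
    kv_layout: str = "HND"

    def compat_key(self):
        # Mirrors what the __eq__ tests check: num_kv_heads is deliberately excluded.
        return (self.head_dim, self.dtype, self.kv_layout)


def group_compatible(handlers):
    """Group handler names whose caches could share one pooled allocation."""
    groups = defaultdict(list)
    for name, handler in handlers.items():
        groups[handler.compat_key()].append(name)
    return dict(groups)


handlers = {
    "kv_cache_0": ToyKVHandler(8, 64),
    "kv_cache_1": ToyKVHandler(4, 64),  # different num_kv_heads -> still compatible
    "kv_cache_2": ToyKVHandler(8, 64, kv_layout="NHD"),  # different layout -> separate group
}
print(group_compatible(handlers))
# {(64, torch.float16, 'HND'): ['kv_cache_0', 'kv_cache_1'],
#  (64, torch.float16, 'NHD'): ['kv_cache_2']}
```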
@ -23,7 +23,7 @@ def test_update_kv_cache():
|
||||
print("k_cache: " + str(k_cache))
|
||||
print("v_cache: " + str(v_cache))
|
||||
print("input_pos: " + str(torch.tensor([0, 0])))
|
||||
print("cache_loc: " + str(torch.tensor([0, 1])))
|
||||
print("slot_idx: " + str(torch.tensor([0, 1])))
|
||||
print("seq_start: " + str(torch.tensor([0, 3])))
|
||||
|
||||
update_kv_cache(
|
||||
@ -33,7 +33,7 @@ def test_update_kv_cache():
|
||||
v_cache,
|
||||
torch.tensor([3, 1]).long(),
|
||||
torch.tensor([0, 0]),
|
||||
cache_loc=torch.tensor([0, 1]),
|
||||
slot_idx=torch.tensor([0, 1]),
|
||||
seq_start=torch.tensor([0, 3]).long(),
|
||||
)
|
||||
|
||||
|
||||
@ -52,8 +52,8 @@ def test_update_kv_cache(k_d_head, v_d_head, seq_lens, dtype):
|
||||
MAX_SEQ_LEN = 64
|
||||
MAX_BATCH_SIZE = 16
|
||||
SEQ_LENS = seq_lens
|
||||
CACHE_LOCS = list(range(0, len(SEQ_LENS)))
|
||||
random.shuffle(CACHE_LOCS)
|
||||
SLOT_IDX = list(range(0, len(SEQ_LENS)))
|
||||
random.shuffle(SLOT_IDX)
|
||||
INPUT_POS = [
random.randint(0, 16) for _ in range(len(SEQ_LENS))
]  # The starting position in the cache for each of the sequences.
|
||||
@ -89,7 +89,7 @@ def test_update_kv_cache(k_d_head, v_d_head, seq_lens, dtype):
|
||||
k_cache,
|
||||
v_cache,
|
||||
torch.tensor(INPUT_POS, device=DEVICE, dtype=torch.int32),
|
||||
torch.tensor(CACHE_LOCS, device=DEVICE, dtype=torch.int32),
|
||||
torch.tensor(SLOT_IDX, device=DEVICE, dtype=torch.int32),
|
||||
MAX_SEQ_LEN,
|
||||
N_KV_HEADS,
|
||||
K_D_HEAD,
|
||||
@ -99,13 +99,13 @@ def test_update_kv_cache(k_d_head, v_d_head, seq_lens, dtype):
|
||||
)
|
||||
|
||||
# Check if the cache was correctly updated:
|
||||
for i, cache_loc in enumerate(CACHE_LOCS):
|
||||
for i, slot_idx in enumerate(SLOT_IDX):
|
||||
assert torch.equal(
|
||||
k_cache[cache_loc, INPUT_POS[i] : INPUT_POS[i] + SEQ_LENS[i], :N_KV_HEADS, :].squeeze(),
|
||||
k_cache[slot_idx, INPUT_POS[i] : INPUT_POS[i] + SEQ_LENS[i], :N_KV_HEADS, :].squeeze(),
|
||||
k[i].squeeze(),
|
||||
)
|
||||
assert torch.equal(
|
||||
v_cache[cache_loc, INPUT_POS[i] : INPUT_POS[i] + SEQ_LENS[i], :N_KV_HEADS, :].squeeze(),
|
||||
v_cache[slot_idx, INPUT_POS[i] : INPUT_POS[i] + SEQ_LENS[i], :N_KV_HEADS, :].squeeze(),
|
||||
v[i].squeeze(),
|
||||
)
|
||||
|
||||
@ -118,8 +118,8 @@ def test_attention_kv_flash_decoding(d_head):
|
||||
N_HEADS = 1
|
||||
D_HEAD = d_head
|
||||
MAX_SEQ_LEN = 64
|
||||
CACHE_LOCS = list(range(0, BATCH_SIZE))
|
||||
random.shuffle(CACHE_LOCS)
|
||||
SLOT_IDX = list(range(0, BATCH_SIZE))
|
||||
random.shuffle(SLOT_IDX)
|
||||
INPUT_POS = [0] * BATCH_SIZE
|
||||
# Q,K,V are computed using GEMM.
|
||||
q = torch.randn(BATCH_SIZE, 1, N_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE)
|
||||
@ -137,7 +137,7 @@ def test_attention_kv_flash_decoding(d_head):
|
||||
k_cache,
|
||||
v_cache,
|
||||
torch.tensor(INPUT_POS, device=DEVICE, dtype=torch.int32),
|
||||
torch.tensor(CACHE_LOCS, device=DEVICE, dtype=torch.int32),
|
||||
torch.tensor(SLOT_IDX, device=DEVICE, dtype=torch.int32),
|
||||
MAX_SEQ_LEN,
|
||||
N_HEADS,
|
||||
D_HEAD,
|
||||
@ -165,7 +165,7 @@ def test_attention_kv_flash_decoding(d_head):
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
torch.tensor(CACHE_LOCS, device=DEVICE, dtype=torch.int32),
|
||||
torch.tensor(SLOT_IDX, device=DEVICE, dtype=torch.int32),
|
||||
torch.tensor(INPUT_POS, device=DEVICE, dtype=torch.int32),
|
||||
output_tensor,
|
||||
output_logsumexp,
|
||||
@ -185,8 +185,8 @@ def test_attention_kv_flash_decoding(d_head):
|
||||
ref.append(
|
||||
torch.nn.functional.scaled_dot_product_attention(
|
||||
q[i, :, :, :].unsqueeze(0).transpose(1, 2), # [BSND]->[BNSD]
|
||||
k_cache[CACHE_LOCS[i], 0 : INPUT_POS[i] + 1, :, :].unsqueeze(0).transpose(1, 2),
|
||||
v_cache[CACHE_LOCS[i], 0 : INPUT_POS[i] + 1, :, :].unsqueeze(0).transpose(1, 2),
|
||||
k_cache[SLOT_IDX[i], 0 : INPUT_POS[i] + 1, :, :].unsqueeze(0).transpose(1, 2),
|
||||
v_cache[SLOT_IDX[i], 0 : INPUT_POS[i] + 1, :, :].unsqueeze(0).transpose(1, 2),
|
||||
)
|
||||
)
|
||||
ref = torch.cat(ref, dim=0)
|
||||
@ -212,7 +212,7 @@ def test_gqa_attention_kv_flash_decoding(q_d_head, v_d_head, n_heads, n_kv_heads
|
||||
Q_D_HEAD = q_d_head
|
||||
V_D_HEAD = v_d_head
|
||||
MAX_SEQ_LEN = 64
|
||||
CACHE_LOCS = list(range(0, BATCH_SIZE))
|
||||
SLOT_IDX = list(range(0, BATCH_SIZE))
|
||||
INPUT_POS = [0] * BATCH_SIZE
|
||||
# Q,K,V are computed using GEMM.
|
||||
q = torch.randn(BATCH_SIZE, 1, N_HEADS, Q_D_HEAD, dtype=DTYPE, device=DEVICE)
|
||||
@ -221,7 +221,7 @@ def test_gqa_attention_kv_flash_decoding(q_d_head, v_d_head, n_heads, n_kv_heads
|
||||
k_cache = torch.randn(BATCH_SIZE, MAX_SEQ_LEN, N_KV_HEADS, Q_D_HEAD, dtype=DTYPE, device=DEVICE)
|
||||
v_cache = torch.randn(BATCH_SIZE, MAX_SEQ_LEN, N_KV_HEADS, V_D_HEAD, dtype=DTYPE, device=DEVICE)
|
||||
|
||||
cache_loc = torch.tensor(CACHE_LOCS, device=DEVICE, dtype=torch.int32)
|
||||
slot_idx = torch.tensor(SLOT_IDX, device=DEVICE, dtype=torch.int32)
|
||||
input_pos = torch.tensor(INPUT_POS, device=DEVICE, dtype=torch.int32)
|
||||
|
||||
grid = (BATCH_SIZE, N_KV_HEADS, 1)
|
||||
@ -233,7 +233,7 @@ def test_gqa_attention_kv_flash_decoding(q_d_head, v_d_head, n_heads, n_kv_heads
|
||||
k_cache,
|
||||
v_cache,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
slot_idx,
|
||||
MAX_SEQ_LEN,
|
||||
N_KV_HEADS,
|
||||
Q_D_HEAD,
|
||||
@ -263,7 +263,7 @@ def test_gqa_attention_kv_flash_decoding(q_d_head, v_d_head, n_heads, n_kv_heads
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
cache_loc,
|
||||
slot_idx,
|
||||
input_pos,
|
||||
output_tensor,
|
||||
output_logsumexp,
|
||||
@ -285,8 +285,8 @@ def test_gqa_attention_kv_flash_decoding(q_d_head, v_d_head, n_heads, n_kv_heads
|
||||
|
||||
ref = []
|
||||
for i in range(BATCH_SIZE):
|
||||
kk = k_cache[CACHE_LOCS[i], 0 : INPUT_POS[i] + 1, :, :].unsqueeze(0)
|
||||
vv = v_cache[CACHE_LOCS[i], 0 : INPUT_POS[i] + 1, :, :].unsqueeze(0)
|
||||
kk = k_cache[SLOT_IDX[i], 0 : INPUT_POS[i] + 1, :, :].unsqueeze(0)
|
||||
vv = v_cache[SLOT_IDX[i], 0 : INPUT_POS[i] + 1, :, :].unsqueeze(0)
|
||||
|
||||
if N_HEADS != N_KV_HEADS:
|
||||
kk = repeat_kv(q[i, :, :, :].unsqueeze(0), kk)
|
||||
@ -453,8 +453,8 @@ def test_context_attention_kv_flattened(
|
||||
V_D_HEAD = v_d_head
|
||||
MAX_SEQ_LEN = 64
|
||||
SEQ_LENS = [36, 44, 12, 1, 1]
|
||||
CACHE_LOCS = list(range(0, len(SEQ_LENS)))
|
||||
random.shuffle(CACHE_LOCS)
|
||||
SLOT_IDX = list(range(0, len(SEQ_LENS)))
|
||||
random.shuffle(SLOT_IDX)
|
||||
INPUT_POS = [2, 4, 8, 0, 1]  # The starting position in the cache for each of the sequences.
|
||||
q = []
|
||||
k = []
|
||||
@ -476,10 +476,10 @@ def test_context_attention_kv_flattened(
|
||||
def compute_reference(q, k_cache, v_cache):
|
||||
ref = []
|
||||
for i in range(len(SEQ_LENS)):
|
||||
kk = k_cache[CACHE_LOCS[i], : INPUT_POS[i] + SEQ_LENS[i], :N_KV_HEADS, :].view(
|
||||
kk = k_cache[SLOT_IDX[i], : INPUT_POS[i] + SEQ_LENS[i], :N_KV_HEADS, :].view(
|
||||
1, INPUT_POS[i] + SEQ_LENS[i], N_KV_HEADS, K_D_HEAD
|
||||
)
|
||||
vv = v_cache[CACHE_LOCS[i], : INPUT_POS[i] + SEQ_LENS[i], :N_KV_HEADS, :].view(
|
||||
vv = v_cache[SLOT_IDX[i], : INPUT_POS[i] + SEQ_LENS[i], :N_KV_HEADS, :].view(
|
||||
1, INPUT_POS[i] + SEQ_LENS[i], N_KV_HEADS, V_D_HEAD
|
||||
)
|
||||
|
||||
@ -528,7 +528,7 @@ def test_context_attention_kv_flattened(
|
||||
seq_start_indices = torch.zeros(len(SEQ_LENS), device=DEVICE, dtype=torch.int32)
|
||||
seq_start_indices[1:] = torch.cumsum(seq_len[:-1], 0)
|
||||
input_pos = torch.tensor(INPUT_POS, device=DEVICE, dtype=torch.int32)
|
||||
cache_loc = torch.tensor(CACHE_LOCS, device=DEVICE, dtype=torch.int32)
|
||||
slot_idx = torch.tensor(SLOT_IDX, device=DEVICE, dtype=torch.int32)
|
||||
SEQ_BLOCK = 32
|
||||
output_tensor = torch.empty((1, sum(SEQ_LENS), N_HEADS, V_D_HEAD), dtype=DTYPE, device=DEVICE)
|
||||
grid = (len(SEQ_LENS), N_KV_HEADS, (max(SEQ_LENS) + SEQ_BLOCK - 1) // SEQ_BLOCK)
|
||||
@ -540,7 +540,7 @@ def test_context_attention_kv_flattened(
|
||||
k_cache,
|
||||
v_cache,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
slot_idx,
|
||||
MAX_SEQ_LEN,
|
||||
N_KV_HEADS,
|
||||
K_D_HEAD,
|
||||
@ -550,7 +550,7 @@ def test_context_attention_kv_flattened(
|
||||
)
|
||||
|
||||
# Check if the cache was correctly updated:
|
||||
for i, cl in enumerate(CACHE_LOCS):
|
||||
for i, cl in enumerate(SLOT_IDX):
|
||||
assert torch.equal(
|
||||
k_cache[cl, INPUT_POS[i] : INPUT_POS[i] + SEQ_LENS[i], :N_KV_HEADS, :].squeeze(),
|
||||
k[i].squeeze(),
|
||||
@ -568,7 +568,7 @@ def test_context_attention_kv_flattened(
|
||||
k_cache,
|
||||
v_cache,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
slot_idx,
|
||||
output_tensor,
|
||||
1.0 / math.sqrt(Q_D_HEAD),
|
||||
N_HEADS,
|
||||
@ -604,8 +604,8 @@ def test_update_kv_cache_rope_fusion(seq_lens, n_heads, n_kv_heads, dtype):
|
||||
MAX_BATCH_SIZE = 16
|
||||
SEQ_LENS = seq_lens
|
||||
BATCH_SIZE = len(SEQ_LENS)
|
||||
CACHE_LOCS = list(range(0, BATCH_SIZE))
|
||||
random.shuffle(CACHE_LOCS)
|
||||
SLOT_IDX = list(range(0, BATCH_SIZE))
|
||||
random.shuffle(SLOT_IDX)
|
||||
INPUT_POS = [
random.randint(0, 16) for _ in range(BATCH_SIZE)
]  # The starting position in the cache for each of the sequences.
|
||||
@ -651,7 +651,7 @@ def test_update_kv_cache_rope_fusion(seq_lens, n_heads, n_kv_heads, dtype):
|
||||
k_cache,
|
||||
v_cache,
|
||||
torch.tensor(INPUT_POS, device=DEVICE, dtype=torch.int32),
|
||||
torch.tensor(CACHE_LOCS, device=DEVICE, dtype=torch.int32),
|
||||
torch.tensor(SLOT_IDX, device=DEVICE, dtype=torch.int32),
|
||||
freqs,
|
||||
MAX_SEQ_LEN,
|
||||
N_HEADS,
|
||||
@ -697,7 +697,7 @@ def test_update_kv_cache_rope_fusion(seq_lens, n_heads, n_kv_heads, dtype):
|
||||
start = end
|
||||
|
||||
# Check if the cache was correctly updated:
|
||||
for i, cl in enumerate(CACHE_LOCS):
|
||||
for i, cl in enumerate(SLOT_IDX):
|
||||
assert torch.allclose(
|
||||
k_cache[cl, INPUT_POS[i] : INPUT_POS[i] + SEQ_LENS[i]].squeeze(),
|
||||
k_ref[i].squeeze(),
|
||||
|
||||
@ -1,537 +0,0 @@
|
||||
import math
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
from _custom_op_utils import torch_reference_mha_stage2
|
||||
from _model_test_utils import repeat_kv
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.triton_kernels.attention_with_paged_kv_cache import (
|
||||
attention_kv_paged_stage1,
|
||||
context_attention_kv_paged,
|
||||
update_paged_kv_cache,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens",
|
||||
[
|
||||
[16, 8, 9, 21], # context only sequences
|
||||
[1, 1, 1, 1, 1, 1], # decode only sequences
|
||||
[5, 10, 4, 1, 1, 1], # context + decode sequences
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
def test_update_paged_kv_cache(seq_lens, dtype):
|
||||
DEVICE = "cuda"
|
||||
DTYPE = dtype
|
||||
N_KV_HEADS = 8
|
||||
D_HEAD = 16
|
||||
MAX_SEQ_LEN = 64
|
||||
SEQ_LENS = seq_lens # 2 context
|
||||
BATCH_SIZE = len(SEQ_LENS)
|
||||
CACHE_LOCS = list(range(BATCH_SIZE))
|
||||
random.shuffle(CACHE_LOCS)
|
||||
NUM_PAGES = 256
|
||||
PAGE_SIZE = 4
|
||||
|
||||
k = []
|
||||
v = []
|
||||
for i, s in enumerate(SEQ_LENS):
|
||||
k.append(torch.randn(1, s, N_KV_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE))
|
||||
v.append(torch.randn(1, s, N_KV_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE))
|
||||
|
||||
(k_f, v_f) = tuple(map(lambda x: torch.cat(x, 1), (k, v)))
|
||||
|
||||
k_cache = torch.zeros(NUM_PAGES, PAGE_SIZE, N_KV_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE)
|
||||
v_cache = torch.zeros(NUM_PAGES, PAGE_SIZE, N_KV_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE)
|
||||
|
||||
# length of the existing kv cache for each sequence
# (0 if it is a context batch)
|
||||
CACHE_LENS = []
|
||||
for b in range(BATCH_SIZE):
|
||||
CACHE_LENS.append(random.randint(0, 4 * PAGE_SIZE))
|
||||
|
||||
# allocate pages for kv cache
|
||||
# pages of each batch are contiguous
|
||||
PAGE_TABLE = [None] * BATCH_SIZE
|
||||
PAGES_PER_SEQ = (MAX_SEQ_LEN + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
cnt = 0
|
||||
for b in range(BATCH_SIZE):
|
||||
# allocate pages for history kv cache and new coming kv
|
||||
length = CACHE_LENS[b] + SEQ_LENS[b]
|
||||
allocated_pages = (length + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
table = []
|
||||
for p in range(PAGES_PER_SEQ):
|
||||
if p < allocated_pages:
|
||||
table.append(cnt)
|
||||
cnt = cnt + 1
|
||||
else:
|
||||
table.append(0)
|
||||
PAGE_TABLE[CACHE_LOCS[b]] = table
|
||||
page_table = torch.tensor(PAGE_TABLE, device=DEVICE, dtype=torch.int32)
|
||||
|
||||
GENERATE_ONLY = all(s == 1 for s in SEQ_LENS)
|
||||
SEQ_BLOCK = PAGE_SIZE if GENERATE_ONLY else 32
|
||||
grid = (BATCH_SIZE, N_KV_HEADS, (max(SEQ_LENS) + SEQ_BLOCK - 1) // SEQ_BLOCK)
|
||||
|
||||
seq_len = torch.tensor(SEQ_LENS, device=DEVICE, dtype=torch.int32)
|
||||
seq_start_indices = torch.zeros(BATCH_SIZE, device=DEVICE, dtype=torch.int32)
|
||||
seq_start_indices[1:] = torch.cumsum(seq_len[:-1], 0)
|
||||
|
||||
update_paged_kv_cache[grid](
|
||||
k_f,
|
||||
v_f,
|
||||
seq_len,
|
||||
seq_start_indices,
|
||||
k_cache,
|
||||
v_cache,
|
||||
torch.tensor(CACHE_LOCS, device=DEVICE, dtype=torch.int32),
|
||||
torch.tensor(CACHE_LENS, device=DEVICE, dtype=torch.int32),
|
||||
page_table,
|
||||
N_KV_HEADS,
|
||||
D_HEAD,
|
||||
SEQ_BLOCK,
|
||||
MAX_SEQ_LEN,
|
||||
PAGE_SIZE,
|
||||
page_table.stride(0),
|
||||
GENERATE_ONLY,
|
||||
)
|
||||
|
||||
# Check if the cache was correctly updated:
|
||||
for batch, kv_batch in enumerate(CACHE_LOCS):
|
||||
batch_page_table = page_table[kv_batch]
|
||||
cache_len = CACHE_LENS[batch]
|
||||
update_len = SEQ_LENS[batch]
|
||||
if cache_len == 0:
|
||||
# context batch
|
||||
for seq_page, kv_page in enumerate(batch_page_table):
|
||||
if seq_page * PAGE_SIZE >= update_len:
|
||||
break
|
||||
start = 0
|
||||
end = min(update_len - seq_page * PAGE_SIZE, PAGE_SIZE)
|
||||
assert torch.equal(
|
||||
k_cache[kv_page, start:end].squeeze(),
|
||||
k[batch][:, seq_page * PAGE_SIZE : (seq_page + 1) * PAGE_SIZE].squeeze(),
|
||||
)
|
||||
assert torch.equal(
|
||||
v_cache[kv_page, start:end].squeeze(),
|
||||
v[batch][:, seq_page * PAGE_SIZE : (seq_page + 1) * PAGE_SIZE].squeeze(),
|
||||
)
|
||||
else:
|
||||
# decode batch, only check one token in one page
|
||||
check_page = cache_len // PAGE_SIZE
|
||||
kv_page = batch_page_table[check_page]
|
||||
start = cache_len % PAGE_SIZE
|
||||
end = start + 1
|
||||
assert torch.equal(
|
||||
k_cache[kv_page, start:end].squeeze(),
|
||||
k[batch][:, 0:1].squeeze(),
|
||||
)
|
||||
assert torch.equal(
|
||||
v_cache[kv_page, start:end].squeeze(),
|
||||
v[batch][:, 0:1].squeeze(),
|
||||
)
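The cache-update checks above locate a token in the paged cache by splitting its position into a logical page index and an in-page offset, then translating the logical page through the per-sequence page table. A minimal, standalone sketch of that address math follows; it is plain Python for illustration only, and `locate_token` is a hypothetical helper, not part of the kernels under test.

```python
PAGE_SIZE = 4  # tokens per page, matching the tests above


def locate_token(page_table, position):
    """Map an absolute token position to (physical_page, in_page_offset)."""
    logical_page = position // PAGE_SIZE  # which logical page of the sequence holds the token
    offset = position % PAGE_SIZE  # slot within that page
    return page_table[logical_page], offset


# Example: a sequence whose logical pages map to scattered physical pages.
page_table = [7, 2, 9, 0]
assert locate_token(page_table, 0) == (7, 0)  # first token -> physical page 7, slot 0
assert locate_token(page_table, 5) == (2, 1)  # token 5 -> second logical page, slot 1
```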
|
||||
|
||||
|
||||
def test_attention_kv_paged_flash_decoding():
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.float16
|
||||
N_HEADS = 32
|
||||
D_HEAD = 32
|
||||
MAX_SEQ_LEN = 64
|
||||
NUM_PAGES = 256
|
||||
PAGE_SIZE = 4
|
||||
|
||||
CACHE_LEN = [44, 33, 18, 11, 25]
|
||||
BATCH_SIZE = len(CACHE_LEN)
|
||||
SEQ_LENS = []
|
||||
for _ in range(BATCH_SIZE):
|
||||
SEQ_LENS.append(1)
|
||||
# only used for page table indexing
|
||||
CACHE_LOCS = list(range(0, BATCH_SIZE))
|
||||
random.shuffle(CACHE_LOCS)
|
||||
|
||||
# Q,K,V are computed using GEMM.
|
||||
qkv = torch.randn(BATCH_SIZE, 3, N_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE) * 2
|
||||
q, k, v = torch.split(qkv, [1, 1, 1], dim=1)
|
||||
q, k, v = (x.contiguous() for x in (q, k, v))
|
||||
k_cache = torch.zeros(NUM_PAGES, PAGE_SIZE, N_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE)
|
||||
v_cache = torch.zeros(NUM_PAGES, PAGE_SIZE, N_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE)
|
||||
|
||||
# allocate pages for kv cache
|
||||
# pages of each batch are contiguous
|
||||
PAGE_TABLE = [None] * BATCH_SIZE
|
||||
PAGES_PER_SEQ = (MAX_SEQ_LEN + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
cnt = 0
|
||||
for b in range(BATCH_SIZE):
|
||||
length = CACHE_LEN[b]
|
||||
allocated_pages = (length + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
table = []
|
||||
for p in range(PAGES_PER_SEQ):
|
||||
if p < allocated_pages:
|
||||
table.append(cnt)
|
||||
cnt = cnt + 1
|
||||
else:
|
||||
table.append(0)
|
||||
PAGE_TABLE[CACHE_LOCS[b]] = table
|
||||
page_table = torch.tensor(PAGE_TABLE, device=DEVICE, dtype=torch.int32)
|
||||
|
||||
# prepare kv-cache
|
||||
for b in range(BATCH_SIZE):
|
||||
pages = PAGE_TABLE[CACHE_LOCS[b]]
|
||||
cache_l = CACHE_LEN[b]
|
||||
page_num = cache_l // PAGE_SIZE
|
||||
page_off = cache_l % PAGE_SIZE
|
||||
for p in range(page_num):
|
||||
k_cache[pages[p]] = torch.randn(
|
||||
(1, PAGE_SIZE, N_HEADS, D_HEAD), dtype=DTYPE, device=DEVICE
|
||||
)
|
||||
v_cache[pages[p]] = torch.randn(
|
||||
(1, PAGE_SIZE, N_HEADS, D_HEAD), dtype=DTYPE, device=DEVICE
|
||||
)
|
||||
k_cache[pages[page_num], 0:page_off] = torch.randn(
|
||||
(1, page_off, N_HEADS, D_HEAD), dtype=DTYPE, device=DEVICE
|
||||
)
|
||||
v_cache[pages[page_num], 0:page_off] = torch.randn(
|
||||
(1, page_off, N_HEADS, D_HEAD), dtype=DTYPE, device=DEVICE
|
||||
)
|
||||
|
||||
SEQ_BLOCK_SIZE = PAGE_SIZE
|
||||
# Input position 0 implies that kv-cache is empty
|
||||
num_blocks = MAX_SEQ_LEN // SEQ_BLOCK_SIZE
|
||||
output_tensor = torch.zeros(
|
||||
BATCH_SIZE, N_HEADS, num_blocks, D_HEAD, device=DEVICE, dtype=torch.float32
|
||||
)
|
||||
output_logsumexp = torch.zeros(
|
||||
BATCH_SIZE, N_HEADS, num_blocks, device=DEVICE, dtype=torch.float32
|
||||
) - float("inf")
|
||||
|
||||
grid = (BATCH_SIZE, N_HEADS, (max(SEQ_LENS) + SEQ_BLOCK_SIZE - 1) // SEQ_BLOCK_SIZE)
|
||||
|
||||
seq_len = torch.tensor(SEQ_LENS, device=DEVICE, dtype=torch.int32)
|
||||
seq_start_indices = torch.zeros(len(SEQ_LENS), device=DEVICE, dtype=torch.int32)
|
||||
seq_start_indices[1:] = torch.cumsum(seq_len[:-1], 0)
|
||||
update_paged_kv_cache[grid](
|
||||
k,
|
||||
v,
|
||||
seq_len,
|
||||
seq_start_indices,
|
||||
k_cache,
|
||||
v_cache,
|
||||
torch.tensor(CACHE_LOCS, device=DEVICE, dtype=torch.int32),
|
||||
torch.tensor(CACHE_LEN, device=DEVICE, dtype=torch.int32),
|
||||
page_table,
|
||||
N_HEADS,
|
||||
D_HEAD,
|
||||
SEQ_BLOCK_SIZE,
|
||||
MAX_SEQ_LEN,
|
||||
PAGE_SIZE,
|
||||
page_table.stride(0),
|
||||
GENERATE_ONLY=True,
|
||||
)
|
||||
|
||||
# Check if the cache was correctly updated:
|
||||
for batch, kv_batch in enumerate(CACHE_LOCS):
|
||||
batch_page_table = page_table[kv_batch]
|
||||
# decode batch, only check one token in one page
|
||||
cache_len = CACHE_LEN[batch]
|
||||
check_page = cache_len // PAGE_SIZE
|
||||
kv_page = batch_page_table[check_page]
|
||||
start = cache_len % PAGE_SIZE
|
||||
end = start + 1
|
||||
assert torch.equal(
|
||||
k_cache[kv_page, start:end].squeeze(),
|
||||
k[batch].squeeze(),
|
||||
)
|
||||
assert torch.equal(
|
||||
v_cache[kv_page, start:end].squeeze(),
|
||||
v[batch].squeeze(),
|
||||
)
|
||||
|
||||
def run():
|
||||
attention_kv_paged_stage1[
|
||||
(
|
||||
BATCH_SIZE,
|
||||
N_HEADS,
|
||||
num_blocks,
|
||||
)
|
||||
](
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
torch.tensor(CACHE_LOCS, device=DEVICE, dtype=torch.int32),
|
||||
page_table,
|
||||
torch.tensor(CACHE_LEN, device=DEVICE, dtype=torch.int32),
|
||||
output_tensor,
|
||||
output_logsumexp,
|
||||
num_blocks,
|
||||
MAX_SEQ_LEN,
|
||||
N_HEADS,
|
||||
N_HEADS,
|
||||
D_HEAD,
|
||||
SEQ_BLOCK_SIZE,
|
||||
PAGE_SIZE,
|
||||
page_table.stride(0),
|
||||
)
|
||||
|
||||
run()
|
||||
|
||||
# This needs to be another kernel if torch-trt doesn't support broadcast + div.
|
||||
output = torch_reference_mha_stage2(output_tensor, output_logsumexp)
|
||||
|
||||
_ref = []
|
||||
for b in range(BATCH_SIZE):
|
||||
pages = PAGE_TABLE[CACHE_LOCS[b]]
|
||||
cache_l = CACHE_LEN[b]
|
||||
page_num = cache_l // PAGE_SIZE
|
||||
page_off = cache_l % PAGE_SIZE
|
||||
_k = []
|
||||
_v = []
|
||||
for p in range(page_num):
|
||||
_k.append(k_cache[pages[p]].reshape([-1, N_HEADS, D_HEAD]))
|
||||
_v.append(v_cache[pages[p]].reshape([-1, N_HEADS, D_HEAD]))
|
||||
_k.append(k_cache[pages[page_num], 0 : page_off + 1].reshape([-1, N_HEADS, D_HEAD]))
|
||||
_v.append(v_cache[pages[page_num], 0 : page_off + 1].reshape([-1, N_HEADS, D_HEAD]))
|
||||
_ref.append(
|
||||
torch.nn.functional.scaled_dot_product_attention(
|
||||
q[b].reshape([1, N_HEADS, 1, D_HEAD]),
|
||||
torch.cat(_k, 0).reshape(1, -1, N_HEADS, D_HEAD).transpose(1, 2),
|
||||
torch.cat(_v, 0).reshape(1, -1, N_HEADS, D_HEAD).transpose(1, 2),
|
||||
).transpose(2, 1)
|
||||
)
|
||||
ref = torch.cat(_ref, 1)
|
||||
|
||||
assert torch.allclose(
|
||||
ref.squeeze().cpu().to(torch.float32),
|
||||
output.squeeze().cpu().to(torch.float32),
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: run(),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
def compute_flops():
|
||||
flops = BATCH_SIZE * N_HEADS * (D_HEAD * D_HEAD * num_blocks * SEQ_BLOCK_SIZE) # S = q*K
|
||||
flops += (
|
||||
BATCH_SIZE
|
||||
* N_HEADS
|
||||
* (D_HEAD * num_blocks * SEQ_BLOCK_SIZE * num_blocks * SEQ_BLOCK_SIZE)
|
||||
) # S*V
|
||||
return flops
|
||||
|
||||
print("Time: %0.2f ms" % ms)
|
||||
print("GFLOPs: %0.2f" % (compute_flops() / ms / 1e6))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dtype",
|
||||
["float16", "float32", "bfloat16"],
|
||||
)
|
||||
@pytest.mark.parametrize("n_heads, n_kv_heads", [(8, 8), (8, 1)])
|
||||
def test_context_attention_kv_paged(n_heads, n_kv_heads, dtype):
|
||||
DEVICE = "cuda"
|
||||
DTYPE = getattr(torch, dtype)
|
||||
N_HEADS = n_heads
|
||||
N_KV_HEADS = n_kv_heads
|
||||
D_HEAD = 16
|
||||
MAX_SEQ_LEN = 64
|
||||
SEQ_LENS = [36, 43, 21, 14, 18, 1, 1]
|
||||
BATCH_SIZE = len(SEQ_LENS)
|
||||
CACHE_LOCS = list(range(BATCH_SIZE))
|
||||
random.shuffle(CACHE_LOCS)
|
||||
NUM_PAGES = 256
|
||||
PAGE_SIZE = 4
|
||||
SEQ_BLOCK = 32
|
||||
|
||||
q = []
|
||||
k = []
|
||||
v = []
|
||||
for i, s in enumerate(SEQ_LENS):
|
||||
q.append(torch.randn(1, s, N_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE) + i)
|
||||
k.append(torch.randn(1, s, N_KV_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE) + i)
|
||||
v.append(torch.randn(1, s, N_KV_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE) + i)
|
||||
|
||||
(q_f, k_f, v_f) = tuple(map(lambda x: torch.cat(x, 1).contiguous(), (q, k, v)))
|
||||
|
||||
k_cache = torch.zeros(NUM_PAGES, PAGE_SIZE, N_KV_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE)
|
||||
v_cache = torch.zeros(NUM_PAGES, PAGE_SIZE, N_KV_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE)
|
||||
|
||||
CACHE_LENS = []
|
||||
for b in range(BATCH_SIZE):
|
||||
CACHE_LENS.append(random.randint(0, 4 * PAGE_SIZE))
|
||||
|
||||
# allocate pages for kv cache
|
||||
# pages of each batch are contiguous
|
||||
PAGE_TABLE = [None] * BATCH_SIZE
|
||||
PAGES_PER_SEQ = (MAX_SEQ_LEN + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
cnt = 0
|
||||
for b in range(BATCH_SIZE):
|
||||
# allocate pages for history kv cache and new coming kv
|
||||
length = CACHE_LENS[b] + SEQ_LENS[b]
|
||||
allocated_pages = (length + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
table = []
|
||||
for p in range(PAGES_PER_SEQ):
|
||||
if p < allocated_pages:
|
||||
table.append(cnt)
|
||||
cnt = cnt + 1
|
||||
else:
|
||||
table.append(0)
|
||||
PAGE_TABLE[CACHE_LOCS[b]] = table
|
||||
# prepare values for the kv cache of decode batches
|
||||
cache_pages = CACHE_LENS[b] // PAGE_SIZE
|
||||
cache_page_off = CACHE_LENS[b] % PAGE_SIZE
|
||||
k_cache[table[0] : table[cache_pages]] = torch.randn(
|
||||
cache_pages, PAGE_SIZE, N_KV_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE
|
||||
)
|
||||
v_cache[table[0] : table[cache_pages]] = torch.randn(
|
||||
cache_pages, PAGE_SIZE, N_KV_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE
|
||||
)
|
||||
k_cache[table[cache_pages], 0:cache_page_off] = torch.randn(
|
||||
cache_page_off, N_KV_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE
|
||||
)
|
||||
v_cache[table[cache_pages], 0:cache_page_off] = torch.randn(
|
||||
cache_page_off, N_KV_HEADS, D_HEAD, dtype=DTYPE, device=DEVICE
|
||||
)
|
||||
|
||||
page_table = torch.tensor(PAGE_TABLE, device=DEVICE, dtype=torch.int32)
|
||||
|
||||
seq_len = torch.tensor(SEQ_LENS, device=DEVICE, dtype=torch.int32)
|
||||
seq_start_indices = torch.zeros(BATCH_SIZE, device=DEVICE, dtype=torch.int32)
|
||||
seq_start_indices[1:] = torch.cumsum(seq_len[:-1], 0)
|
||||
|
||||
softmax_scale = 1.0 / math.sqrt(D_HEAD)
|
||||
output_tensor = torch.empty_like(q_f)
|
||||
|
||||
grid = (BATCH_SIZE, N_KV_HEADS, (max(SEQ_LENS) + SEQ_BLOCK - 1) // SEQ_BLOCK)
|
||||
update_paged_kv_cache[grid](
|
||||
k_f,
|
||||
v_f,
|
||||
seq_len,
|
||||
seq_start_indices,
|
||||
k_cache,
|
||||
v_cache,
|
||||
torch.tensor(CACHE_LOCS, device=DEVICE, dtype=torch.int32),
|
||||
torch.tensor(CACHE_LENS, device=DEVICE, dtype=torch.int32),
|
||||
page_table,
|
||||
N_KV_HEADS,
|
||||
D_HEAD,
|
||||
SEQ_BLOCK,
|
||||
MAX_SEQ_LEN,
|
||||
PAGE_SIZE,
|
||||
page_table.stride(0),
|
||||
GENERATE_ONLY=False,
|
||||
)
|
||||
|
||||
# Check if the cache was correctly updated:
|
||||
for batch, kv_batch in enumerate(CACHE_LOCS):
|
||||
batch_page_table = page_table[kv_batch]
|
||||
cache_len = CACHE_LENS[batch]
|
||||
update_len = SEQ_LENS[batch]
|
||||
if cache_len == 0:
|
||||
# context batch
|
||||
for seq_page, kv_page in enumerate(batch_page_table):
|
||||
if seq_page * PAGE_SIZE >= update_len:
|
||||
break
|
||||
start = 0
|
||||
end = min(update_len - seq_page * PAGE_SIZE, PAGE_SIZE)
|
||||
assert torch.equal(
|
||||
k_cache[kv_page, start:end].squeeze(),
|
||||
k[batch][:, seq_page * PAGE_SIZE : (seq_page + 1) * PAGE_SIZE].squeeze(),
|
||||
)
|
||||
assert torch.equal(
|
||||
v_cache[kv_page, start:end].squeeze(),
|
||||
v[batch][:, seq_page * PAGE_SIZE : (seq_page + 1) * PAGE_SIZE].squeeze(),
|
||||
)
|
||||
else:
|
||||
# decode batch, only check one token in one page
|
||||
check_page = cache_len // PAGE_SIZE
|
||||
kv_page = batch_page_table[check_page]
|
||||
start = cache_len % PAGE_SIZE
|
||||
end = start + 1
|
||||
assert torch.equal(
|
||||
k_cache[kv_page, start:end].squeeze(),
|
||||
k[batch][:, 0:1].squeeze(),
|
||||
)
|
||||
assert torch.equal(
|
||||
v_cache[kv_page, start:end].squeeze(),
|
||||
v[batch][:, 0:1].squeeze(),
|
||||
)
|
||||
|
||||
grid = (BATCH_SIZE, N_HEADS, (max(SEQ_LENS) + SEQ_BLOCK - 1) // SEQ_BLOCK)
|
||||
context_attention_kv_paged[grid](
|
||||
q_f,
|
||||
seq_len,
|
||||
seq_start_indices,
|
||||
k_cache,
|
||||
v_cache,
|
||||
torch.tensor(CACHE_LOCS, device=DEVICE, dtype=torch.int32),
|
||||
torch.tensor(CACHE_LENS, device=DEVICE, dtype=torch.int32),
|
||||
page_table,
|
||||
softmax_scale,
|
||||
output_tensor,
|
||||
N_HEADS,
|
||||
N_KV_HEADS,
|
||||
D_HEAD,
|
||||
SEQ_BLOCK,
|
||||
MAX_SEQ_LEN,
|
||||
PAGE_SIZE,
|
||||
page_table.stride(0),
|
||||
num_stages=2,
|
||||
)
|
||||
|
||||
def compute_reference(q, k_cache, v_cache):
|
||||
ref = []
|
||||
for batch in range(BATCH_SIZE):
|
||||
table = page_table[CACHE_LOCS[batch]]
|
||||
length = CACHE_LENS[batch] + SEQ_LENS[batch]
|
||||
cache_pages = length // PAGE_SIZE
|
||||
cache_page_off = length % PAGE_SIZE
|
||||
kk = []
|
||||
vv = []
|
||||
kk.append(
|
||||
k_cache[table[0] : table[0] + cache_pages].reshape(
|
||||
1, cache_pages * PAGE_SIZE, N_KV_HEADS, D_HEAD
|
||||
)
|
||||
)
|
||||
kk.append(
|
||||
k_cache[table[0] + cache_pages, 0:cache_page_off].reshape(
|
||||
1, cache_page_off, N_KV_HEADS, D_HEAD
|
||||
)
|
||||
)
|
||||
k_f = torch.cat(kk, 1)
|
||||
vv.append(
|
||||
v_cache[table[0] : table[0] + cache_pages].reshape(
|
||||
1, cache_pages * PAGE_SIZE, N_KV_HEADS, D_HEAD
|
||||
)
|
||||
)
|
||||
vv.append(
|
||||
v_cache[table[0] + cache_pages, 0:cache_page_off].reshape(
|
||||
1, cache_page_off, N_KV_HEADS, D_HEAD
|
||||
)
|
||||
)
|
||||
v_f = torch.cat(vv, 1)
|
||||
if N_HEADS != N_KV_HEADS:
|
||||
k_f = repeat_kv(q[batch], k_f)
|
||||
v_f = repeat_kv(q[batch], v_f)
|
||||
mask = torch.tril(
|
||||
torch.ones(q[batch].shape[1], k_f.shape[1], dtype=torch.bool),
|
||||
diagonal=k_f.shape[1] - q[batch].shape[1],
|
||||
)
|
||||
ref.append(
|
||||
torch.nn.functional.scaled_dot_product_attention(
|
||||
q[batch].transpose(1, 2),
|
||||
k_f.transpose(1, 2),
|
||||
v_f.transpose(1, 2),
|
||||
attn_mask=mask.to(DEVICE),
|
||||
).transpose(2, 1)
|
||||
)
|
||||
return torch.cat(ref, 1)
|
||||
|
||||
ref = compute_reference(q, k_cache, v_cache)
|
||||
assert torch.allclose(ref, output_tensor, atol=1e-2, rtol=1e-2)
|
||||
@ -10,8 +10,10 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import (
|
||||
PagedResourceHandler,
|
||||
CausalConvResourceHandler,
|
||||
KVPagedResourceHandler,
|
||||
SequenceInfo,
|
||||
SSMResourceHandler,
|
||||
StateResourceHandler,
|
||||
UnpagedResourceHandler,
|
||||
)
|
||||
@ -137,7 +139,7 @@ def test_init_default_device_is_cuda():
|
||||
|
||||
|
||||
def test_add_resource_paged_handler(paged_kv_cache_config):
|
||||
"""Test adding a PagedResourceHandler resource."""
|
||||
"""Test adding a KVPagedResourceHandler resource."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
@ -145,15 +147,15 @@ def test_add_resource_paged_handler(paged_kv_cache_config):
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
handler = PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
interface.add_resource("k_cache_0", handler)
|
||||
handler = KVPagedResourceHandler(8, 64, dtype=torch.float16, kv_layout="HND")
|
||||
interface.add_resource("kv_cache_0", handler)
|
||||
|
||||
assert "k_cache_0" in interface._resource_lookup
|
||||
assert interface._resource_lookup["k_cache_0"] is handler
|
||||
assert "kv_cache_0" in interface._resource_lookup
|
||||
assert interface._resource_lookup["kv_cache_0"] is handler
|
||||
|
||||
|
||||
def test_add_resource_state_handler(paged_kv_cache_config):
|
||||
"""Test adding a StateResourceHandler resource."""
|
||||
"""Test adding a SSMResourceHandler resource."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
@ -161,7 +163,7 @@ def test_add_resource_state_handler(paged_kv_cache_config):
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
|
||||
handler = SSMResourceHandler(num_heads=4, head_dim=64, d_state=16, dtype=torch.bfloat16)
|
||||
interface.add_resource("ssm_state_0", handler)
|
||||
|
||||
assert interface._resource_lookup["ssm_state_0"] is handler
|
||||
@ -192,12 +194,12 @@ def test_add_multiple_resources(paged_kv_cache_config):
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
k_handler = PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
v_handler = PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
ssm_handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16)
|
||||
kv_handler_0 = KVPagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
kv_handler_1 = KVPagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
ssm_handler = SSMResourceHandler(num_heads=4, head_dim=64, d_state=16, dtype=torch.bfloat16)
|
||||
|
||||
interface.add_resource("k_cache_0", k_handler)
|
||||
interface.add_resource("v_cache_0", v_handler)
|
||||
interface.add_resource("kv_cache_0", kv_handler_0)
|
||||
interface.add_resource("kv_cache_1", kv_handler_1)
|
||||
interface.add_resource("ssm_state_0", ssm_handler)
|
||||
|
||||
assert len(interface._resource_lookup) == 3
|
||||
@ -217,9 +219,9 @@ def test_initialize_resources_paged_only_creates_kv_cache_manager(paged_kv_cache
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
# Add only paged resources
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
# Add only paged resources (combined KV cache)
|
||||
interface.add_resource("kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("kv_cache_1", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
|
||||
num_caches = interface.initialize_resources()
|
||||
|
||||
@ -238,9 +240,12 @@ def test_initialize_resources_mixed_creates_mamba_hybrid_cache_manager(paged_kv_
|
||||
)
|
||||
|
||||
# Add paged and state resources
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("ssm_state_0", StateResourceHandler(4, 64, 16, dtype=torch.bfloat16))
|
||||
interface.add_resource("kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("kv_cache_1", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource(
|
||||
"ssm_state_0",
|
||||
SSMResourceHandler(num_heads=4, head_dim=64, d_state=16, dtype=torch.bfloat16),
|
||||
)
|
||||
|
||||
num_caches = interface.initialize_resources()
|
||||
|
||||
@ -259,26 +264,25 @@ def test_initialize_resources_creates_cache_views_with_correct_shape(paged_kv_ca
|
||||
|
||||
num_kv_heads = 8
|
||||
head_dim = 64
|
||||
# Using HND layout (default)
|
||||
interface.add_resource(
|
||||
"k_cache_0", PagedResourceHandler(num_kv_heads, head_dim, dtype=torch.float16)
|
||||
)
|
||||
interface.add_resource(
|
||||
"v_cache_0", PagedResourceHandler(num_kv_heads, head_dim, dtype=torch.float16)
|
||||
"kv_cache_0",
|
||||
KVPagedResourceHandler(num_kv_heads, head_dim, dtype=torch.float16, kv_layout="HND"),
|
||||
)
|
||||
|
||||
interface.initialize_resources()
|
||||
|
||||
# Check cache views exist
|
||||
assert "k_cache_0" in interface._caches
|
||||
assert "v_cache_0" in interface._caches
|
||||
# Check cache view exists
|
||||
assert "kv_cache_0" in interface._caches
|
||||
|
||||
# Check shapes: [num_blocks, tokens_per_block, num_kv_heads, head_dim]
|
||||
k_cache = interface._caches["k_cache_0"]
|
||||
assert k_cache is not None
|
||||
assert k_cache.shape[1] == paged_kv_cache_config.tokens_per_block
|
||||
assert k_cache.shape[2] == num_kv_heads
|
||||
assert k_cache.shape[3] == head_dim
|
||||
assert k_cache.dtype == torch.float16
|
||||
# Check shape for HND layout: [num_blocks, 2, num_kv_heads, tokens_per_block, head_dim]
|
||||
kv_cache = interface._caches["kv_cache_0"]
|
||||
assert kv_cache is not None
|
||||
assert kv_cache.shape[1] == 2 # K and V
|
||||
assert kv_cache.shape[2] == num_kv_heads
|
||||
assert kv_cache.shape[3] == paged_kv_cache_config.tokens_per_block
|
||||
assert kv_cache.shape[4] == head_dim
|
||||
assert kv_cache.dtype == torch.float16
|
||||
|
||||
|
||||
def test_initialize_resources_creates_state_views_with_correct_shape(paged_kv_cache_config):
|
||||
@ -293,10 +297,12 @@ def test_initialize_resources_creates_state_views_with_correct_shape(paged_kv_ca
|
||||
num_heads = 4
|
||||
head_dim = 64
|
||||
ssm_state_size = 16
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource(
|
||||
"ssm_state_0",
|
||||
StateResourceHandler(num_heads, head_dim, ssm_state_size, dtype=torch.bfloat16),
|
||||
SSMResourceHandler(
|
||||
num_heads=num_heads, head_dim=head_dim, d_state=ssm_state_size, dtype=torch.bfloat16
|
||||
),
|
||||
)
|
||||
|
||||
interface.initialize_resources()
|
||||
@ -345,11 +351,11 @@ def test_is_paged_returns_true_for_paged_only(paged_kv_cache_config):
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("kv_cache_1", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.initialize_resources()
|
||||
|
||||
assert interface.is_paged() is True
|
||||
assert interface.kv_cache_config_tuned.enable_block_reuse is True
|
||||
|
||||
|
||||
def test_is_paged_returns_false_for_hybrid(paged_kv_cache_config):
|
||||
@ -361,11 +367,14 @@ def test_is_paged_returns_false_for_hybrid(paged_kv_cache_config):
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("ssm_state_0", StateResourceHandler(4, 64, 16, dtype=torch.bfloat16))
|
||||
interface.add_resource("kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource(
|
||||
"ssm_state_0",
|
||||
SSMResourceHandler(num_heads=4, head_dim=64, d_state=16, dtype=torch.bfloat16),
|
||||
)
|
||||
interface.initialize_resources()
|
||||
|
||||
assert interface.is_paged() is False
|
||||
assert interface.kv_cache_config_tuned.enable_block_reuse is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@ -382,7 +391,7 @@ def test_needs_resize_returns_false_when_fraction_is_zero(paged_kv_cache_config)
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.initialize_resources()
|
||||
|
||||
assert interface.needs_resize() is False
|
||||
@ -397,7 +406,7 @@ def test_needs_resize_returns_true_when_fraction_is_positive(resizable_kv_cache_
|
||||
kv_cache_config=resizable_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.initialize_resources()
|
||||
|
||||
assert interface.needs_resize() is True
|
||||
@ -412,7 +421,7 @@ def test_resize_kv_cache_manager_skipped_when_not_needed(paged_kv_cache_config):
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.initialize_resources()
|
||||
|
||||
# Get initial state
|
||||
@ -439,19 +448,19 @@ def test_shutdown_clears_caches(paged_kv_cache_config):
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("v_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("kv_cache_1", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.initialize_resources()
|
||||
|
||||
assert len(interface._caches) == 2
|
||||
|
||||
interface.shutdown()
|
||||
|
||||
assert len(interface._caches) == 0
|
||||
assert all(cache is None for cache in interface._caches.values())
|
||||
|
||||
|
||||
def test_clear_cache_views_sets_views_to_none(paged_kv_cache_config):
|
||||
"""Test _clear_cache_views() sets paged and state cache views to None."""
|
||||
def test_clear_caches_clears_all(paged_kv_cache_config):
|
||||
"""Test _clear_caches() clears all cache entries."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
@ -459,16 +468,20 @@ def test_clear_cache_views_sets_views_to_none(paged_kv_cache_config):
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("ssm_state_0", StateResourceHandler(4, 64, 16, dtype=torch.bfloat16))
|
||||
interface.add_resource("kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource(
|
||||
"ssm_state_0",
|
||||
SSMResourceHandler(num_heads=4, head_dim=64, d_state=16, dtype=torch.bfloat16),
|
||||
)
|
||||
interface.initialize_resources()
|
||||
|
||||
# Manually call _clear_cache_views
|
||||
interface._clear_cache_views()
|
||||
assert len(interface._caches) == 2
|
||||
|
||||
# Paged and state caches should be None
|
||||
assert interface._caches["k_cache_0"] is None
|
||||
assert interface._caches["ssm_state_0"] is None
|
||||
# Manually call _clear_caches
|
||||
interface._clear_caches()
|
||||
|
||||
# All caches should be cleared
|
||||
assert all(cache is None for cache in interface._caches.values())
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@ -534,7 +547,7 @@ def test_named_args_includes_sequence_info_and_caches(paged_kv_cache_config):
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.initialize_resources()
|
||||
|
||||
named_args = interface.named_args
|
||||
@ -544,7 +557,7 @@ def test_named_args_includes_sequence_info_and_caches(paged_kv_cache_config):
|
||||
assert "position_ids" in named_args
|
||||
|
||||
# Should contain cache
|
||||
assert "k_cache_0" in named_args
|
||||
assert "kv_cache_0" in named_args
|
||||
|
||||
|
||||
def test_args_returns_tuple_of_tensors(paged_kv_cache_config):
|
||||
@ -556,7 +569,7 @@ def test_args_returns_tuple_of_tensors(paged_kv_cache_config):
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
interface.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.add_resource("kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
interface.initialize_resources()
|
||||
|
||||
args = interface.args
|
||||
@ -722,3 +735,154 @@ def test_sequence_info_page_assignments():
|
||||
|
||||
page_assignments = seq_info.page_assignments
|
||||
assert page_assignments == [[0], [1, 2]]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Typed State Resource Handler Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_ssm_resource_handler_state_shape():
|
||||
"""Test SSMResourceHandler returns correct state_shape property."""
|
||||
handler = SSMResourceHandler(num_heads=8, head_dim=64, d_state=16, dtype=torch.bfloat16)
|
||||
assert handler.state_shape == (8, 64, 16)
|
||||
assert handler.num_heads == 8
|
||||
assert handler.head_dim == 64
|
||||
assert handler.d_state == 16
|
||||
assert handler.dtype == torch.bfloat16
|
||||
|
||||
|
||||
def test_causal_conv_resource_handler_state_shape():
|
||||
"""Test CausalConvResourceHandler returns correct state_shape property."""
|
||||
# d_conv=4 means state stores d_conv-1=3 elements
|
||||
handler = CausalConvResourceHandler(conv_dim=256, d_conv=4, dtype=torch.float32)
|
||||
assert handler.state_shape == (256, 3) # (conv_dim, d_conv - 1)
|
||||
assert handler.conv_dim == 256
|
||||
assert handler.d_conv == 4
|
||||
assert handler.dtype == torch.float32
|
||||
|
||||
|
||||
def test_typed_handlers_inherit_from_state_resource_handler():
|
||||
"""Test that SSMResourceHandler and CausalConvResourceHandler inherit from StateResourceHandler."""
|
||||
ssm_handler = SSMResourceHandler(num_heads=4, head_dim=64, d_state=16, dtype=torch.bfloat16)
|
||||
conv_handler = CausalConvResourceHandler(conv_dim=256, d_conv=4, dtype=torch.float32)
|
||||
|
||||
assert isinstance(ssm_handler, StateResourceHandler)
|
||||
assert isinstance(conv_handler, StateResourceHandler)
|
||||
|
||||
|
||||
def test_multiple_ssm_resources_contiguous_views(paged_kv_cache_config):
|
||||
"""Test that multiple SSM resources get contiguous views from MambaHybridCacheManager."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
# Add 3 SSM resources with same parameters (compatible)
|
||||
for i in range(3):
|
||||
interface.add_resource(
|
||||
f"ssm_state_{i}",
|
||||
SSMResourceHandler(num_heads=4, head_dim=64, d_state=16, dtype=torch.bfloat16),
|
||||
)
|
||||
|
||||
interface.initialize_resources()
|
||||
|
||||
# Verify all SSM views are contiguous
|
||||
for i in range(3):
|
||||
ssm_cache = interface._caches[f"ssm_state_{i}"]
|
||||
assert ssm_cache is not None
|
||||
assert ssm_cache.is_contiguous(), f"SSM view {i} is not contiguous"
|
||||
|
||||
|
||||
def test_multiple_conv_resources_contiguous_views(paged_kv_cache_config):
|
||||
"""Test that multiple Conv resources get contiguous views from MambaHybridCacheManager."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
# Add 3 Conv resources with same parameters (compatible)
|
||||
for i in range(3):
|
||||
interface.add_resource(
|
||||
f"conv_state_{i}",
|
||||
CausalConvResourceHandler(conv_dim=256, d_conv=4, dtype=torch.float32),
|
||||
)
|
||||
|
||||
interface.initialize_resources()
|
||||
|
||||
# Verify all Conv views are contiguous
|
||||
for i in range(3):
|
||||
conv_cache = interface._caches[f"conv_state_{i}"]
|
||||
assert conv_cache is not None
|
||||
assert conv_cache.is_contiguous(), f"Conv view {i} is not contiguous"
|
||||
|
||||
|
||||
def test_mixed_ssm_conv_resources_uses_min_layers(paged_kv_cache_config):
|
||||
"""Test that when both SSM and Conv resources exist, uses min(ssm_count, conv_count) layers."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
# Conv params must satisfy n_groups constraint with SSM params:
|
||||
# conv_dim = head_dim * num_heads + 2 * n_groups * d_state
|
||||
# With num_heads=4, head_dim=64, d_state=16, n_groups=2:
|
||||
# conv_dim = 64 * 4 + 2 * 2 * 16 = 256 + 64 = 320
|
||||
num_heads = 4
|
||||
head_dim = 64
|
||||
d_state = 16
|
||||
n_groups = 2
|
||||
conv_dim = head_dim * num_heads + 2 * n_groups * d_state
|
||||
|
||||
# Add 3 SSM and 2 Conv resources
|
||||
for i in range(3):
|
||||
interface.add_resource(
|
||||
f"ssm_state_{i}",
|
||||
SSMResourceHandler(
|
||||
num_heads=num_heads, head_dim=head_dim, d_state=d_state, dtype=torch.bfloat16
|
||||
),
|
||||
)
|
||||
|
||||
for i in range(2):
|
||||
interface.add_resource(
|
||||
f"conv_state_{i}",
|
||||
CausalConvResourceHandler(conv_dim=conv_dim, d_conv=4, dtype=torch.float32),
|
||||
)
|
||||
|
||||
interface.initialize_resources()
|
||||
|
||||
# Verify MambaHybridCacheManager was created
|
||||
assert isinstance(interface.kv_cache_manager, MambaHybridCacheManager)
|
||||
|
||||
# All caches should exist
|
||||
for i in range(3):
|
||||
assert interface._caches[f"ssm_state_{i}"] is not None
|
||||
for i in range(2):
|
||||
assert interface._caches[f"conv_state_{i}"] is not None
|
||||
|
||||
|
||||
def test_generic_state_handler_allocated_locally(paged_kv_cache_config):
|
||||
"""Test that generic StateResourceHandler (not SSM/Conv) is allocated locally."""
|
||||
interface = CachedSequenceInterface(
|
||||
max_seq_len=128,
|
||||
max_batch_size=4,
|
||||
device="cuda",
|
||||
kv_cache_config=paged_kv_cache_config,
|
||||
)
|
||||
|
||||
# Add a generic StateResourceHandler (not SSM or Conv)
|
||||
generic_handler = StateResourceHandler(10, 20, dtype=torch.float32)
|
||||
interface.add_resource("generic_state", generic_handler)
|
||||
|
||||
interface.initialize_resources()
|
||||
|
||||
# Generic handler should be allocated but via local allocation (not MambaHybridCacheManager)
|
||||
assert interface._caches["generic_state"] is not None
|
||||
# Without typed handlers, should use plain KVCacheManager
|
||||
assert isinstance(interface.kv_cache_manager, KVCacheManager)
|
||||
|
||||
@ -7,8 +7,8 @@ import torch.nn as nn
|
||||
from _model_test_utils import GQA
|
||||
from _torch_test_utils import all_close
|
||||
|
||||
# Initialize resources first
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import PagedResourceHandler
|
||||
# Initialize resources first (KVPagedResourceHandler is used within tests below)
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import KVPagedResourceHandler
|
||||
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
|
||||
from tensorrt_llm._torch.auto_deploy.models.factory import (
|
||||
FullModelExportInfo,
|
||||
@ -292,10 +292,12 @@ def test_initialize_cache_transform_calls_initialize_resources(dummy_cached_inte
|
||||
transform = InitializeCache(config=TransformConfig(stage=Stages.PATTERN_MATCHER))
|
||||
|
||||
# Add a resource to verify initialize_resources is called
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import PagedResourceHandler
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import (
|
||||
KVPagedResourceHandler,
|
||||
)
|
||||
|
||||
dummy_cached_interface.add_resource(
|
||||
"k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
"kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
)
|
||||
|
||||
# Mock the factory and shared_config
|
||||
@ -316,7 +318,7 @@ def test_initialize_cache_transform_calls_initialize_resources(dummy_cached_inte
|
||||
def test_resize_kv_cache_transform_skipped_when_not_needed(dummy_cached_interface):
|
||||
"""Verify ResizeKVCache transform is skipped when resize not needed."""
|
||||
dummy_cached_interface.add_resource(
|
||||
"k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
"kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16)
|
||||
)
|
||||
dummy_cached_interface.initialize_resources()
|
||||
|
||||
@ -356,7 +358,7 @@ def test_resize_kv_cache_transform_runs_when_needed():
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
|
||||
cm.add_resource("k_cache_0", PagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
cm.add_resource("kv_cache_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
|
||||
cm.initialize_resources()
|
||||
|
||||
# Create the transform with a proper config
|
||||
@ -523,7 +525,9 @@ def test_insert_cached_attention_passes_kv_cache_config():
|
||||
# Initialize resources
|
||||
cm.initialize_resources()
|
||||
|
||||
assert not cm.is_paged(), "triton should not use paged resources"
|
||||
assert not any(handler.is_paged for handler in cm._resource_lookup.values()), (
|
||||
"triton should not use paged resources"
|
||||
)
|
    assert cm._caches, "at least some resources should be present"

    # Verify cache dtype matches config