# KV Cache Management: Pools, Blocks, and Events

This document provides an overview of the internal hierarchy and event system for paged KV cache management, as implemented in the TensorRT-LLM codebase.

For more information on KV cache reuse, see the KV cache reuse documentation.
## Hierarchy: Pool, Block, and Page

### Block

- **Definition:** The smallest unit of KV cache allocation. A `KVCacheBlock` holds metadata (not the actual data) for a chunk of KV cache (a simplified sketch follows this list).
- **Purpose:** Each block represents a fixed number of tokens' worth of KV data, specified by the `tokens_per_block` parameter.
- **Usage:** Blocks are allocated, reused, or evicted as sequences are processed.
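As a rough illustration of what block metadata covers, here is a hypothetical sketch; the struct and field names below are invented for this example and are not the real `KVCacheBlock` definition.

```cpp
// Illustrative sketch only -- field names are hypothetical, not the real KVCacheBlock.
#include <cstdint>
#include <vector>

struct KvBlockSketch
{
    std::int32_t blockId;                // index of this block within its pool
    std::int32_t poolId;                 // pool that owns the backing storage
    std::vector<std::int32_t> tokenIds;  // tokens covered by this block (at most tokens_per_block)
    std::uint64_t blockHash;             // hash over the covered tokens, used for reuse lookup
    bool isPrimary;                      // true if the data currently lives in the primary (GPU) pool
    std::int32_t refCount;               // number of sequences currently sharing this block
};
```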
### Page

- **Definition:** In this codebase, "page" is often used interchangeably with "block" (as in "paged KV cache"). Strictly speaking, a page could refer to a hardware-level memory page, whereas a block is a logical unit of the cache.
- **In practice:** The code uses "block" as the main unit; "page" is not a distinct class or struct.
### Pool

- **Definition:** A pool is a contiguous memory buffer (or set of buffers) that holds the actual KV data for one or more layers.
- **Types:** There are primary pools (fast GPU memory) and secondary pools (slower memory, e.g., CPU or offload memory).
- **Organization:** Each pool can serve multiple layers that share the same KV head configuration. Pools are managed by `KVCacheBlockPool` and tracked in vectors in `WindowBlockManager`.
- **Block ↔ Pool:** Each block is an index into a pool; the pool provides the actual storage, while the block is the metadata handle. The sketch after this list shows how such an index can be turned into an offset.
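To make the block-to-pool mapping concrete, here is a minimal sketch of how a block index could be translated into a byte offset inside a pool buffer. The layout, names, and sizing formula are assumptions for illustration, not the exact layout used by TensorRT-LLM.

```cpp
// Hypothetical pool layout: one large buffer sliced into equally sized blocks,
// where each block stores K and V for every layer sharing this pool.
#include <cstddef>

struct PoolSketch
{
    std::size_t numKvHeads;       // KV heads per layer served by this pool
    std::size_t headDim;          // dimension of each KV head
    std::size_t tokensPerBlock;   // tokens_per_block
    std::size_t numLayersInPool;  // layers that share this pool

    // Bytes occupied by one block: K and V (factor 2) for every layer, token, and head.
    std::size_t blockSizeBytes(std::size_t bytesPerElem) const
    {
        return 2 * numLayersInPool * tokensPerBlock * numKvHeads * headDim * bytesPerElem;
    }

    // A block is just an index; the pool turns that index into an offset into its buffer.
    std::size_t blockOffsetBytes(std::size_t blockId, std::size_t bytesPerElem) const
    {
        return blockId * blockSizeBytes(bytesPerElem);
    }
};
```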
### WindowBlockManager/BlockManager

TRT-LLM supports two complex features related to KV cache management:

- Variable Group-Query Attention (VGQA), i.e., a different `num_kv_heads` value for different layers.
- Variable Sliding Window Attention (VSWA), i.e., a different `attention_window_size` value for different layers.

To support both of these features, pool management is organized as described below. In the simplest and most common case, however, where a model uses

- MHA/MQA/non-variable GQA, i.e., the same `num_kv_heads` value for all layers, and
- global attention/SWA, i.e., the same `attention_window_size` value for all layers,

only a single pool is created within the structure described below.
### KV Cache Pool Management

- **WindowBlockManager:** Manages blocks and pools for a specific attention window size. Within a `WindowBlockManager`, there can be multiple pools, each corresponding to a unique number of KV heads, i.e., to support VGQA.
- **BlockManager:** Manages all `WindowBlockManager` instances, one per unique window size (see the sketch after this list).
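A minimal sketch of this two-level structure is shown below. The type and member names are hypothetical and heavily simplified compared to the real C++ classes; the sketch only illustrates how pools, blocks, and the two managers relate.

```cpp
// Hypothetical, simplified view of the manager hierarchy (not the real classes).
#include <cstddef>
#include <map>
#include <vector>

struct PoolSketch { std::size_t numKvHeads; };   // one pool per unique KV-head configuration
struct KvBlockSketch { std::size_t blockId; };   // metadata handle indexing into a pool

// Manages blocks and pools for one attention window size.
struct WindowBlockManagerSketch
{
    std::size_t attentionWindowSize;
    std::vector<PoolSketch> pools;      // multiple pools only when num_kv_heads varies (VGQA)
    std::vector<KvBlockSketch> blocks;  // blocks allocated from those pools
};

// Owns one WindowBlockManager per unique attention window size (VSWA).
struct BlockManagerSketch
{
    std::map<std::size_t, WindowBlockManagerSketch> managersByWindowSize;
};

// For a plain MHA/global-attention model there is a single window size and a
// single KV-head configuration, so this collapses to one manager holding one pool.
```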
Hierarchy summary:

- Pool (memory buffer for KV data)
  - Contains many blocks.
- Block (metadata for a chunk of the pool; each block covers `tokens_per_block` tokens)
  - Optionally, blocks can be swapped between primary and secondary pools.
- BlockManager/WindowBlockManager: manage pools and blocks; handle allocation, reuse, and eviction.
## Events in KVCacheEventManager

The `KVCacheEventManager` is responsible for tracking and reporting significant changes in the state of the KV cache. Events are used for logging and debugging, and can also feed external monitoring.

### Types of Events
- **Created Event:** When pools or blocks are created/allocated.
- **Updated Event:** When a block's state changes (e.g., moved between primary/secondary, priority updated).
- **Removed Event:** When a block is removed from the cache (evicted or released).
- **Stored Event:** When blocks are stored for potential reuse (e.g., after a sequence finishes and its blocks are reusable).
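The event kinds above could be modeled as a tagged union over per-kind payloads. The following is a hypothetical sketch with invented field names, not the actual event structs used by TensorRT-LLM.

```cpp
// Hypothetical event payloads -- names and fields are illustrative only.
#include <cstdint>
#include <variant>
#include <vector>

struct CreatedEventSketch { std::vector<std::int64_t> poolSizes; };      // pools/blocks allocated
struct UpdatedEventSketch { std::int64_t blockId; int newPriority; bool nowInPrimary; };
struct RemovedEventSketch { std::vector<std::int64_t> blockIds; };       // evicted or released blocks
struct StoredEventSketch  { std::vector<std::uint64_t> blockHashes; };   // blocks kept for reuse

struct KvCacheEventSketch
{
    std::uint64_t eventId;  // monotonically increasing sequence number
    std::variant<CreatedEventSketch, UpdatedEventSketch,
                 RemovedEventSketch, StoredEventSketch> data;
};
```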
### What Triggers an Event?

- **Allocation/Deallocation:** Creating or freeing memory pools or blocks.
- **Eviction/Reuse:** When a block is evicted, reused, or its priority changes.
- **Block Movement:** When a block is moved between memory levels (primary ↔ secondary).
- **Block Storage:** When blocks are stored for future reuse (e.g., after a sequence completes).
In summary, an "event" is any significant change in the lifecycle or state of a KV cache block or pool, tracked for monitoring, debugging, or optimization purposes.
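To tie the triggers to the event types above, here is a hypothetical sketch of an event manager that enqueues an event at each of these lifecycle points. The class name, method names, and payload are invented for illustration and are not the real KVCacheEventManager interface.

```cpp
// Hypothetical sketch: lifecycle changes are turned into queued events that a
// logging or monitoring consumer can drain later.
#include <cstdint>
#include <deque>
#include <string>
#include <utility>

struct EventSketch
{
    std::uint64_t id;      // monotonically increasing event id
    std::string kind;      // "created", "updated", "removed", or "stored"
    std::int64_t blockId;  // -1 when the event refers to a whole pool
};

class KvCacheEventManagerSketch
{
public:
    void onPoolsCreated() { push("created", -1); }
    void onBlockUpdated(std::int64_t blockId) { push("updated", blockId); }  // swap or priority change
    void onBlockRemoved(std::int64_t blockId) { push("removed", blockId); }  // eviction or release
    void onBlocksStored(std::int64_t blockId) { push("stored", blockId); }   // reusable after a sequence ends

    // Consumers (logging, external monitoring) periodically drain the queue.
    std::deque<EventSketch> drain() { return std::exchange(events_, {}); }

private:
    void push(std::string kind, std::int64_t blockId)
    {
        events_.push_back(EventSketch{nextId_++, std::move(kind), blockId});
    }

    std::uint64_t nextId_ = 0;
    std::deque<EventSketch> events_;
};
```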