diff --git a/README.md b/README.md index ad0d485137..b796de6a1b 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,12 @@ This branch is a prototype and not stable for production use. PRs are not accept ## Tech Blogs +* [02/06] Accelerating Long-Context Inference with Skip Softmax Attention +✨ [➡️ link](https://nvidia.github.io/TensorRT-LLM/blogs/tech_blog/blog16_Accelerating_Long_Context_Inference_with_Skip_Softmax_Attention.html) + +* [01/09] Optimizing DeepSeek-V3.2 on NVIDIA Blackwell GPUs +✨ [➡️ link](https://nvidia.github.io/TensorRT-LLM/blogs/tech_blog/blog15_Optimizing_DeepSeek_V32_on_NVIDIA_Blackwell_GPUs) + * [10/13] Scaling Expert Parallelism in TensorRT LLM (Part 3: Pushing the Performance Boundary) ✨ [➡️ link](https://nvidia.github.io/TensorRT-LLM/blogs/tech_blog/blog14_Scaling_Expert_Parallelism_in_TensorRT-LLM_part3.html) diff --git a/docs/source/blogs/media/tech_blog16_blackwell_decode.png b/docs/source/blogs/media/tech_blog16_blackwell_decode.png new file mode 100644 index 0000000000..ae138da51c Binary files /dev/null and b/docs/source/blogs/media/tech_blog16_blackwell_decode.png differ diff --git a/docs/source/blogs/media/tech_blog16_blackwell_prefill.png b/docs/source/blogs/media/tech_blog16_blackwell_prefill.png new file mode 100644 index 0000000000..08d626910d Binary files /dev/null and b/docs/source/blogs/media/tech_blog16_blackwell_prefill.png differ diff --git a/docs/source/blogs/media/tech_blog16_blasst.jpg b/docs/source/blogs/media/tech_blog16_blasst.jpg new file mode 100644 index 0000000000..4b96efd03c Binary files /dev/null and b/docs/source/blogs/media/tech_blog16_blasst.jpg differ diff --git a/docs/source/blogs/media/tech_blog16_hopper_decode.png b/docs/source/blogs/media/tech_blog16_hopper_decode.png new file mode 100644 index 0000000000..5fb9799c10 Binary files /dev/null and b/docs/source/blogs/media/tech_blog16_hopper_decode.png differ diff --git a/docs/source/blogs/media/tech_blog16_hopper_prefill.png b/docs/source/blogs/media/tech_blog16_hopper_prefill.png new file mode 100644 index 0000000000..b0663f72a9 Binary files /dev/null and b/docs/source/blogs/media/tech_blog16_hopper_prefill.png differ diff --git a/docs/source/blogs/tech_blog/blog16_Accelerating_Long_Context_Inference_with_Skip_Softmax_Attention.md b/docs/source/blogs/tech_blog/blog16_Accelerating_Long_Context_Inference_with_Skip_Softmax_Attention.md new file mode 100644 index 0000000000..466803e56d --- /dev/null +++ b/docs/source/blogs/tech_blog/blog16_Accelerating_Long_Context_Inference_with_Skip_Softmax_Attention.md @@ -0,0 +1,276 @@ +# Accelerating Long-Context Inference with Skip Softmax Attention + +As context lengths grow from thousands to hundreds of thousands of tokens, attention computation becomes a major bottleneck in long-context LLM inference. TensorRT-LLM provides a [sparse attention framework](../../features/sparse-attention.md#framework-level-sparse-attention) that supports techniques like KV cache compression and sparse pattern prediction, featured in [RocketKV](https://arxiv.org/pdf/2502.14051) and [DSA](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/DeepSeek_V3_2.pdf). However, these require framework-level support—additional token selection steps, auxiliary data structures, and kernel modifications are required, compared to the vanilla attention architecture. This complexity introduces **runtime overhead** that can offset performance gains, particularly when context lengths are not long enough to amortize the extra work. 
+
+In this blog, we introduce **Skip Softmax Attention**, a drop-in sparse attention technique designed to accelerate existing pretrained models that use standard attention mechanisms such as MHA, GQA, or MLA. Skip Softmax Attention is built on top of the Flash Attention algorithm and only requires modifying the existing **attention kernels**. Due to this simplicity, the end-to-end performance gain is more predictable. In addition, since it only approximates the computation inside the attention kernel, it remains compatible with nearly all other features, such as FP8 attention, KV cache reuse, and chunked prefill.
+
+## Table of Contents
+- [Accelerating Long-Context Inference with Skip Softmax Attention](#accelerating-long-context-inference-with-skip-softmax-attention)
+  - [Table of Contents](#table-of-contents)
+  - [Method Overview](#method-overview)
+  - [Example Usage](#example-usage)
+  - [Accuracy Evaluation](#accuracy-evaluation)
+  - [Performance Benchmark](#performance-benchmark)
+    - [Kernel Performance](#kernel-performance)
+    - [End-to-end Performance](#end-to-end-performance)
+  - [Reproduction](#reproduction)
+    - [Accuracy evaluation (LongBench V1/V2)](#accuracy-evaluation-longbench-v1v2)
+    - [End-to-end performance (TTFT/TPOT)](#end-to-end-performance-ttfttpot)
+  - [Conclusion](#conclusion)
+
+## Method Overview
+
+The idea of Skip Softmax Attention is to compare the local maximum $\tilde{m}_i^{(j)}$ of $Q \cdot K^T$ with the running global maximum $m_i^{(j)}$, and to skip the softmax (exp) and BMM2 calculation for blocks where this difference falls below a certain threshold $\lambda$:
+
+$$\tilde{m}_i^{(j)} - m_i^{(j)} < \lambda$$
+
+In this way, we can indirectly control the sparsity via the threshold. The threshold is set to be inversely proportional to the context length, i.e., the longer the context, the smaller the threshold needed to achieve the same sparsity.
+
+The method is fully dynamic and can be applied to both the prefill and decode phases. The algorithm is described in the paper [BLASST: Dynamic Blocked Attention Sparsity via Softmax Thresholding](https://arxiv.org/pdf/2512.12087), and we have also published a [Developer Blog](https://developer.nvidia.com/blog/accelerating-long-context-inference-with-skip-softmax-in-nvidia-tensorrt-llm/) that explains the method. Please refer to these resources for an in-depth dive into the algorithm details. Here, we focus on how Skip Softmax Attention is applied in TensorRT-LLM to accelerate long-context inference.
+
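+
+To make the criterion concrete, the following PyTorch snippet sketches a decode-style online-softmax loop with the skip check applied per KV block. This is a numerical reference for the math above, not the fused TensorRT-LLM kernel; the block size and threshold value are illustrative, and with the inequality as written the threshold should be negative so that only blocks far below the running maximum are skipped.
+
+```python
+import torch
+
+def skip_softmax_attention_reference(q, k, v, threshold=-5.0):
+    """Reference single-query attention with the skip-softmax criterion.
+
+    q: [d], k/v: [n_kv, d]. `threshold` plays the role of lambda above.
+    """
+    block = 128                          # KV tokens processed per iteration
+    scale = q.shape[-1] ** -0.5
+    m = torch.tensor(float("-inf"))      # running global max  m_i^(j)
+    l = torch.tensor(0.0)                # running softmax denominator
+    acc = torch.zeros_like(v[0])         # running (unnormalized) output
+    for start in range(0, k.shape[0], block):
+        s = (k[start:start + block] @ q) * scale  # Q*K^T scores for this block
+        m_tilde = s.max()                         # local max  \tilde{m}_i^(j)
+        if m_tilde - m < threshold:
+            # The block would only contribute negligible exp(s - m) terms:
+            # skip the softmax (exp) and the BMM2 with V entirely.
+            continue
+        m_new = torch.maximum(m, m_tilde)
+        correction = torch.exp(m - m_new)         # rescale earlier partial sums
+        p = torch.exp(s - m_new)
+        acc = acc * correction + p @ v[start:start + block]
+        l = l * correction + p.sum()
+        m = m_new
+    return acc / l
+```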
+
+*(Figure: Prefill and Decode panels for Hopper (H200) and Blackwell (B200).)*
+
-Figure 1: The sparse attention framework in TensorRT LLM.
+Figure 1: Framework-level sparse attention support in TensorRT LLM.
Our goal is to design a general, extensible, and flexible sparse attention framework. In this framework, the attention operator provides the unified APIs to support both **sparse computation** and **sparse KV cache** that leverage token sparsity, while the users/developers can only focus on the algorithm of sparse attentions, i.e. how to accurately identify important query-key pairs. For the generality, TensorRT LLM abstracts sparse attention into a prediction-based workflow: *a prediction module first identifies the sparse indices (tokens/blocks to keep or attend to), which are then used by the subsequent attention operator*. Currently, for standard attention (MQA/GQA), TensorRT LLM supports **sparse KV cache** in the context phase and **sparse computation** in the generation phase. Different KV heads are allowed to use different sparse indices, while Q heads that map to the same KV head share the same sparse pattern. It does **not** yet support sparse computation in the context phase or sparse KV cache in the generation phase. -For the scalability, figure 1 illustrates the overall design. The architecture is built by inheriting from the existing `AttentionBackend` to define algorithm-specific sparse attention backends. Within these backends, `prediction` methods are implemented to generate the corresponding sparse indices. These indices are then passed as arguments to the `AttentionOp` to perform the sparse attention computation. This approach balances system flexibility with extensibility, allowing new algorithms to be integrated by simply defining their prediction logic **without** modifying the core attention kernels. +For the scalability, Figure 1 illustrates the overall design. The architecture is built by inheriting from the existing `AttentionBackend` to define algorithm-specific sparse attention backends. Within these backends, `prediction` methods are implemented to generate the corresponding sparse indices. These indices are then passed as arguments to the `AttentionOp` to perform the sparse attention computation. This approach balances system flexibility with extensibility, allowing new algorithms to be integrated by simply defining their prediction logic **without** modifying the core attention kernels. -TensorRT LLM currently supports the following features: +TensorRT LLM currently supports the following features in the framework: 1. **Context Phase**: * **sparse computation**: MLA @@ -126,7 +196,7 @@ TensorRT LLM currently supports the following features: * **sparse computation**: MLA/MQA/GQA * **sparse KV cache**: no support yet -### Framework Implementation +#### Framework Implementation To hide the complexity of sparse algorithms, the main prediction logic is encapsulated within the `tensorrt_llm._torch.attention_backend` module. @@ -159,11 +229,11 @@ The key files located in `tensorrt_llm/_torch/attention_backend/sparse/` are: In `AttentionOp`, currently, the MQA/GQA sparse attention only supports sparse computation at block granularity in the generation phase, where the block size equals to the page size of the KV cache. It means that we can skip the attention computation of those unimportant pages. In addition, we provide a sparse MLA kernel that supports token-level sparse computation in both the context and generation phases. 
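+
+As a rough illustration of this prediction-based, page-granular workflow, the sketch below shows how an algorithm-specific prediction step could pick important KV tokens per KV head and convert them into page-aligned indices. The names and heuristics here are illustrative placeholders only; they are not the actual `AttentionBackend`/`AttentionOp` interfaces.
+
+```python
+import torch
+
+def predict_important_tokens(importance: torch.Tensor, keep: int) -> torch.Tensor:
+    """Toy prediction step: keep the `keep` highest-scoring KV tokens per KV head.
+
+    importance: [num_kv_heads, kv_len] scores from an algorithm-specific heuristic.
+    Q heads that map to the same KV head share the resulting pattern.
+    """
+    return importance.topk(min(keep, importance.shape[-1]), dim=-1).indices
+
+def tokens_to_pages(token_indices: torch.Tensor, page_size: int) -> list:
+    """Convert token-level indices to page-aligned indices per KV head.
+
+    If any token inside a page is selected, the whole page is attended to,
+    matching the generation-phase block granularity (block size == page size).
+    """
+    return [torch.unique(torch.div(idx, page_size, rounding_mode="floor"))
+            for idx in token_indices]
+
+importance = torch.rand(2, 1024)                 # 2 KV heads, 1024 cached tokens
+token_idx = predict_important_tokens(importance, keep=64)
+page_idx = tokens_to_pages(token_idx, page_size=32)
+# Each KV head now attends only to its selected pages; the effective kv_len per
+# head becomes len(page_idx[h]) * 32 instead of the full 1024 tokens.
+```
+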
-To support those features, as illustrated in figure 2, we have implemented two kernels for the MQA/GQA path, `updateSparseKvCacheAfterFmha` and `gatherKvPageOffsetsKernel`, applied in the context and generation phases respectively: +To support those features, as illustrated in Figure 2, we have implemented two kernels for the MQA/GQA path, `updateSparseKvCacheAfterFmha` and `gatherKvPageOffsetsKernel`, applied in the context and generation phases respectively: * **`updateSparseKvCacheAfterFmha`**: Invoked in the post-processing stage after the context attention computation. It selects the important KV tokens and write those K/V vectors to the KV cache to reduce the KV cache size. -* **`gatherKvPageOffsetsKernel`**: Executed before the attention computation in the generation phase. It converts the input sparse indices (which can be of arbitrary granularity) into page-aligned indices. This means that if a single token is selected, the entire page it is included in the attention computation. After this conversion, we will get a new `kv_page_offsets` and also an updated `kv_len` that is the number of those selected KV tokens. Then these new metadata are fed into the subsequent attention kernel for computation. +* **`gatherKvPageOffsetsKernel`**: Executed before the attention computation in the generation phase. It converts the input sparse indices (which can be of arbitrary granularity) into page-aligned indices. This means that if a single token is selected, the entire page is included in the attention computation. After this conversion, we will get a new `kv_page_offsets` and also an updated `kv_len` that is the number of those selected KV tokens. Then these new metadata are fed into the subsequent attention kernel for computation. For sparse MLA, the kernel supports token sparsity directly, eliminating the need for `gatherKvPageOffsetsKernel`. However, please note that sparse KV cache support is not yet available. @@ -175,7 +245,7 @@ Many sparse attention algorithms also require additional auxiliary memory. In th Each option has its own advantages and disadvantages, please refer to the [Manage Auxiliary Memory Pool](#3-manage-auxiliary-memory-pool) for more details. -### Implementing a New Algorithm +#### Implementing a New Algorithm Inside the Sparse Attention Framework #### 1. Configuration Class @@ -237,19 +307,38 @@ For tighter integration, you can manage the auxiliary memory within the C++ `KVC * Register your config and backend in `tensorrt_llm/_torch/attention_backend/sparse/utils.py` and `tensorrt_llm/_torch/pyexecutor/_util.py` to ensure the system routes the request to your new backend when the config is present. * Add initialization logic in `cpp/tensorrt_llm/thop/attentionOp.cpp` and `cpp/tensorrt_llm/kernels/sparseAttentionKernels.h` if new C++ level parameters are required. -## Summary and Future Work +#### Future Work -### Current Status - -Currently, the status of the sparse attention framework is as follows: - -1. **Supported Operations**: The `AttentionOp` currently supports **sparse KV cache** in the context phase and **sparse computation** in the generation phase. Other combinations (for example, sparse computation in the context phase) are not yet supported for MQA/GQA. For MLA, sparse computation is supported in both the context and generation phases. -2. **Algorithm Support**: RocketKV is supported in both the vanilla (PyTorch) backend and the TRTLLM backend, while DSA is supported in the TRTLLM backend. 
These implementations validate the generality and scalability of the framework.
-
-### Future Work
-
-* **Sparse Computation in Context Phase**: We plan to introduce sparse computation support for the context phase for MQA/GQA, allowing the TensorRT LLM sparse attention framework to cover more scenarios.
-* **Dynamic Eviction in Generation Phase**: Dynamically evicting KV cache blocks during the generation phase poses significant challenges to KV cache flexibility. While difficult to implement in the current framework, block-level eviction appears to be a promising compromise and is under further exploration.
-* **Unified Auxiliary Memory Management**: We are exploring a unified mechanism to manage auxiliary memory pools. This would allow users to define custom auxiliary spaces more flexibly while automatically inheriting advanced features from the KV cache, such as reuse and offloading.
+* **Sparse Computation in Context Phase**: We plan to introduce sparse computation support for the context phase for MQA/GQA, allowing the framework to cover more scenarios.
+* **Dynamic Eviction in Generation Phase**: Dynamically evicting KV cache blocks during the generation phase poses significant challenges to KV cache flexibility. Block-level eviction appears to be a promising compromise and is under exploration.
+* **Unified Auxiliary Memory Management**: We are exploring a unified mechanism to manage auxiliary memory pools, allowing custom auxiliary spaces to automatically inherit advanced features from the KV cache (e.g., reuse, offloading).
 * **Code Refactoring**: As more sparse attention algorithms are integrated, the framework will undergo refactoring to unify code and improve maintainability.
-* **Optimizations**: We are discussing further optimizations, such as improving DSA performance.
+
+### Kernel-Level Sparse Attention
+
+Unlike framework-level methods, **kernel-level sparse attention** is implemented directly inside the attention kernels. There is no external prediction/gather workflow—the kernel itself decides what to skip based on runtime criteria.
+
+**Skip Softmax Attention (BLASST)** is TensorRT LLM's kernel-level sparse attention method, supported on both **Hopper** and **Blackwell** GPUs for MHA/GQA/MLA, in both prefill and decode phases. It dynamically skips Softmax and BMM2 computation for KV blocks whose contribution falls below a threshold. Because the logic lives entirely inside the kernel, it requires no auxiliary data structures or framework hooks—just set `threshold_scale_factor` in the config. As a result, there is no framework-level runtime overhead, and the attention kernel speedup translates directly into the end-to-end speedup.
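+
+To illustrate how little configuration this requires, a minimal LLM API snippet might look like the sketch below. The config class name and the `sparse_attention_config` argument shown here are illustrative placeholders rather than confirmed API names; only the `threshold_scale_factor` knob comes from the description above, and its value is arbitrary. Please consult the LLM API reference and the sparse attention examples for the exact spelling.
+
+```python
+from tensorrt_llm import LLM, SamplingParams
+# Placeholder import path and class name, for illustration only.
+from tensorrt_llm.llmapi import SkipSoftmaxAttentionConfig
+
+llm = LLM(
+    model="Qwen/Qwen2.5-7B-Instruct-1M",           # any long-context model
+    # Hypothetical argument/class names; `threshold_scale_factor` is the knob
+    # mentioned above, set here to an arbitrary illustrative value.
+    sparse_attention_config=SkipSoftmaxAttentionConfig(threshold_scale_factor=500.0),
+)
+
+outputs = llm.generate(
+    ["Summarize the following report: ..."],
+    SamplingParams(max_tokens=256),
+)
+print(outputs[0].outputs[0].text)
+```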
+ +For algorithm details and end-to-end results, please refer to the following resources: +- **Paper**: [BLASST: Dynamic Blocked Attention Sparsity via Softmax Thresholding](https://arxiv.org/pdf/2512.12087) +- **NVIDIA developer blog**: [Accelerating Long-Context Inference with Skip Softmax Attention](https://developer.nvidia.com/blog/accelerating-long-context-inference-with-skip-softmax-in-nvidia-tensorrt-llm/) +- **Tech blog**: [Accelerating Long-Context Inference with Skip Softmax Attention](../blogs/tech_blog/blog16_Accelerating_Long_Context_Inference_with_Skip_Softmax_Attention.md) + +Skip Softmax Attention is supported only with the **trtllm** attention backend, implemented inside TensorRT-LLM's high-performance attention kernels: +- **Hopper prefill**: [fmha_v2](https://github.com/NVIDIA/TensorRT-LLM/tree/main/cpp/kernels/fmha_v2) +- **Hopper decode**: [XQA](https://github.com/NVIDIA/TensorRT-LLM/tree/main/cpp/kernels/xqa) +- **Blackwell**: [trtllm-gen](https://github.com/NVIDIA/TensorRT-LLM/tree/main/cpp/tensorrt_llm/kernels/trtllmGenKernels) + + +### Summary + +The following table compares the three sparse attention algorithms available in TensorRT LLM: + +| Aspect | RocketKV | DSA | Skip Softmax | +|--------|----------|-----|--------------| +| Prefill Acceleration | No | Yes | Yes | +| Decode Acceleration | Yes | Yes | Yes | +| KV Cache Reduction | Yes | No | No | +| Framework-Level Support Required | Yes | Yes | No | +| Model Native | No | Yes | No |