diff --git a/README.md b/README.md
index ad0d485137..b796de6a1b 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,12 @@ This branch is a prototype and not stable for production use. PRs are not accept
 ## Tech Blogs
+* [02/06] Accelerating Long-Context Inference with Skip Softmax Attention
+✨ [➡️ link](https://nvidia.github.io/TensorRT-LLM/blogs/tech_blog/blog16_Accelerating_Long_Context_Inference_with_Skip_Softmax_Attention.html)
+
+* [01/09] Optimizing DeepSeek-V3.2 on NVIDIA Blackwell GPUs
+✨ [➡️ link](https://nvidia.github.io/TensorRT-LLM/blogs/tech_blog/blog15_Optimizing_DeepSeek_V32_on_NVIDIA_Blackwell_GPUs)
+
 * [10/13] Scaling Expert Parallelism in TensorRT LLM (Part 3: Pushing the Performance Boundary)
 ✨ [➡️ link](https://nvidia.github.io/TensorRT-LLM/blogs/tech_blog/blog14_Scaling_Expert_Parallelism_in_TensorRT-LLM_part3.html)
diff --git a/docs/source/blogs/media/tech_blog16_blackwell_decode.png b/docs/source/blogs/media/tech_blog16_blackwell_decode.png
new file mode 100644
index 0000000000..ae138da51c
Binary files /dev/null and b/docs/source/blogs/media/tech_blog16_blackwell_decode.png differ
diff --git a/docs/source/blogs/media/tech_blog16_blackwell_prefill.png b/docs/source/blogs/media/tech_blog16_blackwell_prefill.png
new file mode 100644
index 0000000000..08d626910d
Binary files /dev/null and b/docs/source/blogs/media/tech_blog16_blackwell_prefill.png differ
diff --git a/docs/source/blogs/media/tech_blog16_blasst.jpg b/docs/source/blogs/media/tech_blog16_blasst.jpg
new file mode 100644
index 0000000000..4b96efd03c
Binary files /dev/null and b/docs/source/blogs/media/tech_blog16_blasst.jpg differ
diff --git a/docs/source/blogs/media/tech_blog16_hopper_decode.png b/docs/source/blogs/media/tech_blog16_hopper_decode.png
new file mode 100644
index 0000000000..5fb9799c10
Binary files /dev/null and b/docs/source/blogs/media/tech_blog16_hopper_decode.png differ
diff --git a/docs/source/blogs/media/tech_blog16_hopper_prefill.png b/docs/source/blogs/media/tech_blog16_hopper_prefill.png
new file mode 100644
index 0000000000..b0663f72a9
Binary files /dev/null and b/docs/source/blogs/media/tech_blog16_hopper_prefill.png differ
diff --git a/docs/source/blogs/tech_blog/blog16_Accelerating_Long_Context_Inference_with_Skip_Softmax_Attention.md b/docs/source/blogs/tech_blog/blog16_Accelerating_Long_Context_Inference_with_Skip_Softmax_Attention.md
new file mode 100644
index 0000000000..466803e56d
--- /dev/null
+++ b/docs/source/blogs/tech_blog/blog16_Accelerating_Long_Context_Inference_with_Skip_Softmax_Attention.md
@@ -0,0 +1,276 @@
+# Accelerating Long-Context Inference with Skip Softmax Attention
+
+As context lengths grow from thousands to hundreds of thousands of tokens, attention computation becomes a major bottleneck in long-context LLM inference. TensorRT-LLM provides a [sparse attention framework](../../features/sparse-attention.md#framework-level-sparse-attention) that supports techniques like KV cache compression and sparse pattern prediction, featured in [RocketKV](https://arxiv.org/pdf/2502.14051) and [DSA](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/DeepSeek_V3_2.pdf). However, these techniques require framework-level support: additional token selection steps, auxiliary data structures, and kernel modifications on top of the vanilla attention architecture. This complexity introduces **runtime overhead** that can offset the performance gains, particularly when context lengths are not long enough to amortize the extra work.
+
+In this blog, we introduce **Skip Softmax Attention**, a drop-in sparse attention technique designed to accelerate existing pretrained models that use standard attention mechanisms such as MHA, GQA, or MLA. Skip Softmax Attention is built on top of the Flash Attention algorithm and only requires modifying the existing **attention kernels**. Due to this simplicity, the end-to-end performance gain is more predictable. In addition, because it only approximates the attention computation inside the kernel, it is compatible with nearly all other features, such as FP8 attention, KV cache reuse, and chunked prefill.
+
+## Table of Contents
+- [Accelerating Long-Context Inference with Skip Softmax Attention](#accelerating-long-context-inference-with-skip-softmax-attention)
+  - [Table of Contents](#table-of-contents)
+  - [Method Overview](#method-overview)
+  - [Example Usage](#example-usage)
+  - [Accuracy Evaluation](#accuracy-evaluation)
+  - [Performance Benchmark](#performance-benchmark)
+    - [Kernel Performance](#kernel-performance)
+    - [End-to-end Performance](#end-to-end-performance)
+  - [Reproduction](#reproduction)
+    - [Accuracy evaluation (LongBench V1/V2)](#accuracy-evaluation-longbench-v1v2)
+    - [End-to-end performance (TTFT/TPOT)](#end-to-end-performance-ttfttpot)
+  - [Conclusion](#conclusion)
+
+## Method Overview
+
+The idea of Skip Softmax Attention is to compare the local maximum $\tilde{m}_i^{(j)}$ of $Q \cdot K^T$ with the running global maximum $m_i^{(j)}$, and to skip the softmax (exp) and BMM2 computation for blocks whose difference falls below a threshold $\lambda$:
+
+$$\tilde{m}_i^{(j)} - m_i^{(j)} < \lambda$$
+
+In this way, the sparsity can be controlled indirectly via the threshold. The threshold is set inversely proportional to the context length, i.e., the longer the context, the smaller the threshold needed to achieve the same sparsity.
+
+The method is fully dynamic and can be applied to both prefill and decode. The algorithm is described in the paper [BLASST: Dynamic Blocked Attention Sparsity via Softmax Thresholding](https://arxiv.org/pdf/2512.12087), and we have also published a [Developer Blog](https://developer.nvidia.com/blog/accelerating-long-context-inference-with-skip-softmax-in-nvidia-tensorrt-llm/) explaining the method. Please refer to these resources for an in-depth dive into the algorithm details; here we focus on applying Skip Softmax Attention in TensorRT-LLM to accelerate long-context inference.
+
+<div align="center">
+  <img src="../media/tech_blog16_blasst.jpg" alt="BLASST Illustration">
+</div>
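+
+To make the skip criterion concrete, below is a minimal NumPy sketch of an online-softmax (Flash-Attention-style) loop that applies the thresholding test per K/V block. It illustrates the idea only and is not the fused TensorRT-LLM kernel; the function, its signature, and the threshold value are hypothetical:
+
+```python
+# Illustrative sketch of the Skip Softmax criterion inside a blocked attention loop.
+import numpy as np
+
+def skip_softmax_attention(q, k, v, lam=-5.0, block=64):
+    """Single-head attention that skips K/V blocks whose local max score
+    falls more than |lam| below the running max (hypothetical threshold)."""
+    scale = 1.0 / np.sqrt(q.shape[-1])
+    m = np.full(q.shape[0], -np.inf)            # running max per query row
+    acc = np.zeros((q.shape[0], v.shape[-1]))   # unnormalized output accumulator
+    denom = np.zeros(q.shape[0])                # softmax denominator accumulator
+    for start in range(0, k.shape[0], block):
+        s = (q @ k[start:start + block].T) * scale   # partial scores (BMM1)
+        m_local = s.max(axis=-1)                     # local max per query row
+        if np.all(m_local - m < lam):                # Skip Softmax criterion
+            continue                                 # skip exp and BMM2 for this block
+        m_new = np.maximum(m, m_local)
+        alpha = np.exp(m - m_new)                    # rescale previously accumulated results
+        p = np.exp(s - m_new[:, None])
+        acc = acc * alpha[:, None] + p @ v[start:start + block]   # BMM2
+        denom = denom * alpha + p.sum(axis=-1)
+        m = m_new
+    return acc / denom[:, None]
+
+# Example: 128 queries attending to 1024 keys/values of head dimension 64.
+rng = np.random.default_rng(0)
+q, k, v = (rng.standard_normal((n, 64)) for n in (128, 1024, 1024))
+out = skip_softmax_attention(q, k, v)
+```
+
+A block that fails the test skips both the exponentials and the second GEMM, which is where the savings come from; in this sketch the first block is never skipped because the running maximum is still unset.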
+
+## Example Usage
+
+Enabling Skip Softmax Attention is straightforward: configure a `SkipSoftmaxAttentionConfig` and pass it to the `LLM` API:
+
+```python
+from tensorrt_llm import LLM
+from tensorrt_llm.llmapi import SkipSoftmaxAttentionConfig
+
+sparse_attention_config = SkipSoftmaxAttentionConfig(threshold_scale_factor=1000.0)
+
+# Alternatively, the threshold_scale_factor can be configured separately for prefill and decode.
+sparse_attention_config = SkipSoftmaxAttentionConfig(threshold_scale_factor={"prefill": 1000.0, "decode": 500.0})
+
+llm = LLM(
+    model="Qwen/Qwen3-30B-A3B-Instruct-2507",
+    sparse_attention_config=sparse_attention_config,
+    # Other LLM arguments...
+)
+```
+
+The configuration can also be specified through the extra LLM API options YAML file. An example of launching an OpenAI-compatible endpoint is shown below:
+
+```bash
+cat >extra_llm_api_options.yaml <<EOF
+# NOTE: keys are assumed to mirror the SkipSoftmaxAttentionConfig fields shown above.
+sparse_attention_config:
+  threshold_scale_factor: 1000.0
+EOF
+
+trtllm-serve Qwen/Qwen3-30B-A3B-Instruct-2507 --extra_llm_api_options extra_llm_api_options.yaml
+```
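+
+Once the server is up, it can be queried with any OpenAI-compatible client. As a minimal sketch, the snippet below uses the `openai` Python package; the `http://localhost:8000` address is the assumed default for `trtllm-serve`, and the request itself is purely illustrative:
+
+```python
+from openai import OpenAI
+
+# Point the standard OpenAI client at the locally launched endpoint.
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")
+
+response = client.chat.completions.create(
+    model="Qwen/Qwen3-30B-A3B-Instruct-2507",
+    messages=[{"role": "user", "content": "Summarize the following long document: ..."}],
+    max_tokens=256,
+)
+print(response.choices[0].message.content)
+```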

+## Performance Benchmark
+
+### Kernel Performance
+
+The charts below show the attention kernel performance with Skip Softmax Attention for both prefill and decode on Hopper (H200) and Blackwell (B200) GPUs.
+
+**Hopper (H200), Prefill**
+
+<img src="../media/tech_blog16_hopper_prefill.png" alt="Hopper prefill kernel">
+
+**Hopper (H200), Decode**
+
+<img src="../media/tech_blog16_hopper_decode.png" alt="Hopper decode kernel">
+
+**Blackwell (B200), Prefill**
+
+<img src="../media/tech_blog16_blackwell_prefill.png" alt="Blackwell prefill kernel">
+
+**Blackwell (B200), Decode**
+
+<img src="../media/tech_blog16_blackwell_decode.png" alt="Blackwell decode kernel">
+
+Skip Softmax Attention can further boost the performance of FP8 attention, though the gain is less significant than with BF16.
+
+### End-to-end Performance
+
+We benchmark the end-to-end performance to demonstrate the benefit of Skip Softmax Attention. Due to the quadratic complexity of attention, the TTFT in long-context scenarios is often a severe blocker for real-world usage. Skip Softmax Attention can significantly reduce the TTFT by accelerating the prefill kernel, and the TPOT can also be reduced if the context length is long enough. The experiment is conducted on a single H200 or B200 GPU, using the exact same dataset as the accuracy evaluation.
+
+**LongBench V1, avg ISL=10k, OSL=6:**
+
+| Target Sparsity | TTFT/ms (H200) | TPOT/ms (H200) | TTFT/ms (B200) | TPOT/ms (B200) |
+|:---------------:|---------------:|---------------:|---------------:|---------------:|
+| 0.0 | 9419.61 | 1731.80 | 4854.55 | 928.45 |
+| 0.1 | 9519.40 | 1746.73 | 4758.06 | 909.08 |
+| 0.2 | 9417.36 | 1729.74 | 4794.23 | 916.64 |
+| 0.3 | 9304.48 | 1711.27 | 4770.26 | 913.51 |
+| 0.4 | 9139.85 | 1684.78 | 4672.09 | 896.25 |
+| 0.5 | 8847.43 | 1633.08 | 4548.07 | 873.80 |
+| 0.6 | 8437.45 | 1560.64 | 4459.08 | 858.60 |
+| 0.7 | 8134.72 | 1508.60 | 4385.12 | 846.64 |
+| 0.8 | 8107.73 | 1507.82 | 4348.80 | 831.88 |
+| 0.9 | 8130.39 | 1516.16 | 4150.44 | 798.93 |
+
+LongBench V1 results are reported with concurrency 64. Due to the nature of in-flight batching, decoding requests may be piggybacked with prefilling requests, so the TPOT is relatively high.
+
+**LongBench V2, avg ISL=130k, OSL=200:**
+
+| Target Sparsity | TTFT/ms (H200) | TPOT/ms (H200) | TTFT/ms (B200) | TPOT/ms (B200) |
+|:---------------:|---------------:|---------------:|---------------:|---------------:|
+| 0.0 | 16486.70 | 9.34 | 6990.59 | 6.30 |
+| 0.1 | 16487.54 | 8.61 | 7024.50 | 6.30 |
+| 0.2 | 16169.69 | 8.61 | 6687.21 | 6.34 |
+| 0.3 | 15750.17 | 8.46 | 6616.12 | 6.33 |
+| 0.4 | 15288.68 | 8.61 | 6432.32 | 6.27 |
+| 0.5 | 14554.04 | 8.45 | 6193.92 | 6.29 |
+| 0.6 | 14323.08 | 8.44 | 5966.53 | 6.32 |
+| 0.7 | 13871.32 | 8.42 | 5769.19 | 6.31 |
+| 0.8 | 12922.99 | 8.58 | 5605.66 | 6.23 |
+| 0.9 | 12507.95 | 8.58 | 5276.67 | 6.29 |
+
+Due to the extremely long context length, we only run LongBench V2 with concurrency 1. In this scenario, prefill and decode are better separated, and we can observe how TTFT and TPOT are affected by the sparsity. Note that the speedup for decoding is less pronounced at small batch sizes. Small batch sizes and a small number of heads (with TP) are closer to real-world long-context serving due to SLO limits, and we are actively optimizing the decoding performance for such scenarios.
+
+## Reproduction
+
+We provide the commands to reproduce the results in the previous sections, as a showcase of how to evaluate the accuracy and benchmark the performance of Skip Softmax Attention.
+
+### Accuracy evaluation (LongBench V1/V2)
+
+Both LongBench V1 and V2 are integrated into the TensorRT-LLM accuracy test suite, `trtllm-eval`. Here are example scripts to run the accuracy evaluation:
+
+```bash
+# Dump the extra LLM API options YAML file.
+cat >extra_llm_api_options.yaml <