mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
532 lines
43 KiB
HTML
532 lines
43 KiB
HTML
<!DOCTYPE html>
|
||
<html class="writer-html5" lang="en" data-content_root="../">
|
||
<head>
|
||
<meta charset="utf-8" /><meta name="viewport" content="width=device-width, initial-scale=1" />
|
||
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||
<title>Multi-Head, Multi-Query, and Group-Query Attention — tensorrt_llm documentation</title>
|
||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=80d5e7a1" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/css/theme.css?v=19f00094" />
|
||
|
||
|
||
<!--[if lt IE 9]>
|
||
<script src="../_static/js/html5shiv.min.js"></script>
|
||
<![endif]-->
|
||
|
||
<script src="../_static/jquery.js?v=5d32c60e"></script>
|
||
<script src="../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
|
||
<script src="../_static/documentation_options.js?v=5929fcd5"></script>
|
||
<script src="../_static/doctools.js?v=9a2dae69"></script>
|
||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||
<script src="../_static/js/theme.js"></script>
|
||
<link rel="index" title="Index" href="../genindex.html" />
|
||
<link rel="search" title="Search" href="../search.html" />
|
||
<link rel="next" title="C++ GPT Runtime" href="gpt-runtime.html" />
|
||
<link rel="prev" title="Adding a Model" href="../architecture/add-model.html" />
|
||
</head>
|
||
|
||
<body class="wy-body-for-nav">
|
||
<div class="wy-grid-for-nav">
|
||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||
<div class="wy-side-scroll">
|
||
<div class="wy-side-nav-search" >
|
||
|
||
|
||
|
||
<a href="../index.html" class="icon icon-home">
|
||
tensorrt_llm
|
||
</a>
|
||
<div role="search">
|
||
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
|
||
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
|
||
<input type="hidden" name="check_keywords" value="yes" />
|
||
<input type="hidden" name="area" value="default" />
|
||
</form>
|
||
</div>
|
||
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
|
||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../overview.html">Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../quick-start-guide.html">Quick Start Guide</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../release-notes.html">Release Notes</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Installation</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../installation/linux.html">Installing on Linux</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../installation/build-from-source-linux.html">Building from Source Code on Linux</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../installation/windows.html">Installing on Windows</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../installation/build-from-source-windows.html">Building from Source Code on Windows</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Architecture</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/overview.html">TensorRT-LLM Architecture</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/core-concepts.html">Model Definition</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/core-concepts.html#compilation">Compilation</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/core-concepts.html#runtime">Runtime</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/core-concepts.html#multi-gpu-and-multi-node-support">Multi-GPU and Multi-Node Support</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/checkpoint.html">TensorRT-LLM Checkpoint</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/workflow.html">TensorRT-LLM Build Workflow</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/add-model.html">Adding a Model</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Advanced</span></p>
|
||
<ul class="current">
|
||
<li class="toctree-l1 current"><a class="current reference internal" href="#">Multi-Head, Multi-Query, and Group-Query Attention</a><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="#important-note">Important Note</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="#padded-and-packed-tensors">Padded and Packed Tensors</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="#context-and-generation-phases">Context and Generation Phases</a><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="#context-phase">Context Phase</a><ul>
|
||
<li class="toctree-l4"><a class="reference internal" href="#fp8-context-fmha">FP8 Context FMHA</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#generation-phase">Generation Phase</a><ul>
|
||
<li class="toctree-l4"><a class="reference internal" href="#xqa-optimization">XQA Optimization</a></li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
</li>
|
||
<li class="toctree-l2"><a class="reference internal" href="#in-flight-batching">In-flight Batching</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="#chunked-context">Chunked Context</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="#kv-cache">KV Cache</a><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="#contiguous-kv-cache">Contiguous KV Cache</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#paged-kv-cache">Paged KV Cache</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toctree-l2"><a class="reference internal" href="#int8-fp8-kv-caches">INT8/FP8 KV Caches</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="#sliding-window-attention-cyclic-rolling-buffer-kv-cache">Sliding Window Attention, Cyclic (Rolling Buffer) KV Cache</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="#streamingllm">StreamingLLM</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="#beam-search">Beam-Search</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="#input-qkv-tensor">Input QKV tensor</a><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="#rotary-positional-embedding-rope">Rotary Positional Embedding (RoPE)</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#alibi">ALiBi</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#scaling-factor-s">Scaling factor(s)</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#cross-attention">Cross Attention</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#relative-attention-bias-rab">Relative Attention Bias (RAB)</a></li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
</li>
|
||
<li class="toctree-l1"><a class="reference internal" href="gpt-runtime.html">C++ GPT Runtime</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="graph-rewriting.html">Graph Rewriting Module</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="batch-manager.html">The Batch Manager in TensorRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="inference-request.html">Inference Request</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="inference-request.html#responses">Responses</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="lora.html">Run gpt-2b + LoRA using GptManager / cpp runtime</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../performance/perf-overview.html">Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../performance/perf-best-practices.html">Best Practices for Tuning the Performance of TensorRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../performance/perf-analysis.html">Performance Analysis</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Reference</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../reference/troubleshooting.html">Troubleshooting</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../reference/support-matrix.html">Support Matrix</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../reference/precision.html">Numerical Precision</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../reference/memory.html">Memory Usage of TensorRT-LLM</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../_cpp_gen/executor.html">Executor</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../_cpp_gen/runtime.html">Runtime</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Blogs</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../blogs/Falcon180B-H200.html">Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../blogs/quantization-in-TRT-LLM.html">Speed up inference with SOTA quantization techniques in TRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../blogs/XQA-kernel.html">New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget</a></li>
|
||
</ul>
|
||
|
||
</div>
|
||
</div>
|
||
</nav>
|
||
|
||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
|
||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||
<a href="../index.html">tensorrt_llm</a>
|
||
</nav>
|
||
|
||
<div class="wy-nav-content">
|
||
<div class="rst-content">
|
||
<div role="navigation" aria-label="Page navigation">
|
||
<ul class="wy-breadcrumbs">
|
||
<li><a href="../index.html" class="icon icon-home" aria-label="Home"></a></li>
|
||
<li class="breadcrumb-item active">Multi-Head, Multi-Query, and Group-Query Attention</li>
|
||
<li class="wy-breadcrumbs-aside">
|
||
<a href="../_sources/advanced/gpt-attention.md.txt" rel="nofollow"> View page source</a>
|
||
</li>
|
||
</ul>
|
||
<hr/>
|
||
</div>
|
||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||
<div itemprop="articleBody">
|
||
|
||
<section id="multi-head-multi-query-and-group-query-attention">
|
||
<span id="gpt-attention"></span><h1>Multi-Head, Multi-Query, and Group-Query Attention<a class="headerlink" href="#multi-head-multi-query-and-group-query-attention" title="Link to this heading"></a></h1>
|
||
<p>This document details the implementation of multi-head attention (MHA),
|
||
multi-query attention (MQA) and group-query attention (GQA) for auto-regressive
|
||
GPT-like models in TensorRT-LLM. As a quick reminder, the multi-head attention
|
||
is the sequence of a batched matmul, a softmax and another batched matmul
|
||
described in the
|
||
<a class="reference external" href="https://arxiv.org/abs/1706.03762">Attention Is All You Need</a> article. <a class="reference external" href="https://arxiv.org/abs/1911.02150">Multi-query Attention (MQA)</a> and <a class="reference external" href="https://arxiv.org/abs/2305.13245">Group-query Attention (GQA)</a> are variants of MHA that use fewer so-called K/V heads than query heads. In TensorRT-LLM, MHA, MQA and GQA are implemented by the operator <a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/functional.py"><code class="docutils literal notranslate"><span class="pre">tensorrt_llm.functional.gpt_attention</span></code></a>.</p>
|
||
<section id="important-note">
|
||
<h2>Important Note<a class="headerlink" href="#important-note" title="Link to this heading"></a></h2>
|
||
<p>As discussed below, the current implementation supports two input modes: Padded
|
||
and packed (non-padded). As the packed mode is always more memory-efficient and
|
||
faster than the padded mode, <em><strong>support for padded mode may be removed in the
|
||
future</strong></em>.</p>
|
||
</section>
|
||
<section id="padded-and-packed-tensors">
|
||
<h2>Padded and Packed Tensors<a class="headerlink" href="#padded-and-packed-tensors" title="Link to this heading"></a></h2>
|
||
<p>In TensorRT-LLM, the GPT attention operator supports two different types
|
||
of QKV inputs: Padded and packed (i.e., non-padded) inputs. The mode is
|
||
determined by the global configuration parameter <code class="docutils literal notranslate"><span class="pre">remove_input_padding</span></code> defined
|
||
in <a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/plugin/plugin.py"><code class="docutils literal notranslate"><span class="pre">tensorrt_llm.plugin</span></code></a>.</p>
|
||
<p>When padding is enabled (that is, <code class="docutils literal notranslate"><span class="pre">remove_input_padding</span></code> is <code class="docutils literal notranslate"><span class="pre">False</span></code>), the sequences
|
||
that are shorter than the <code class="docutils literal notranslate"><span class="pre">max_sequence_length</span></code> are padded to that maximum
|
||
length. It may result in excessive memory consumption as well as unneeded
|
||
computations on padding tokens (in the various matrix multiplications that
|
||
surround the MHA block).</p>
|
||
<p>To overcome that problem, TensorRT-LLM supports a mode without padding where
|
||
the different tokens are packed together and the user provides the operator
|
||
with a 1D tensor containing the lengths of the different sequences. It is
|
||
recommended that users always use packed mode (and support for the padded
|
||
mode may be removed in the future).</p>
|
||
</section>
|
||
<section id="context-and-generation-phases">
|
||
<h2>Context and Generation Phases<a class="headerlink" href="#context-and-generation-phases" title="Link to this heading"></a></h2>
|
||
<p>The GPT attention operator encapsulates different implementations for both
|
||
context and generation phases in auto-regressive models like GPT.</p>
|
||
<section id="context-phase">
|
||
<h3>Context Phase<a class="headerlink" href="#context-phase" title="Link to this heading"></a></h3>
|
||
<p>If the <code class="docutils literal notranslate"><span class="pre">context_fmha_type</span></code> is set to <code class="docutils literal notranslate"><span class="pre">disabled</span></code> (refer to
|
||
<a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/plugin/plugin.py"><code class="docutils literal notranslate"><span class="pre">tensorrt_llm.plugin</span></code></a>),
|
||
the implementation maps to a sequence of GPU kernels that will store the
|
||
intermediate <code class="docutils literal notranslate"><span class="pre">Q*K^T</span></code> tensor in memory before calling the softmax operator. It
|
||
is the slowest method and the memory footprint is significant (quadratically
|
||
depends on the sequence length).</p>
|
||
<p>Otherwise, if <code class="docutils literal notranslate"><span class="pre">context_fmha_type</span></code> is set to <code class="docutils literal notranslate"><span class="pre">enabled</span></code> or
|
||
<code class="docutils literal notranslate"><span class="pre">enabled_with_fp32_acc</span></code> (accumulation in the first batched matmul is forced to
|
||
FP32), the operator triggers a single kernel that performs the whole MHA/MQA
|
||
block. For short sequences, that kernel uses a vanilla
|
||
implementation of MHA/MQA. For larger sequences, this kernel uses the Flash
|
||
Attention algorithm as described in
|
||
<a class="reference external" href="https://arxiv.org/abs/2205.14135">FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness</a>
|
||
and
|
||
<a class="reference external" href="https://arxiv.org/abs/2307.08691">FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning</a>.</p>
|
||
<p>Currently, the implementation triggers extra kernels that apply pre-processing
|
||
to the elements (like RoPE) and populate the KV cache (see below). In a future
|
||
release, the number of such kernels is planned to be reduced in order to
|
||
improve the overall performance.</p>
|
||
<section id="fp8-context-fmha">
|
||
<h4>FP8 Context FMHA<a class="headerlink" href="#fp8-context-fmha" title="Link to this heading"></a></h4>
|
||
<p>When FP8 quantization is activated, the attention can be further accelerated by
|
||
enabling FP8 Context FMHA (<code class="docutils literal notranslate"><span class="pre">use_fp8_context_fmha</span> <span class="pre">=</span> <span class="pre">enable</span></code>).</p>
|
||
<p>FP8 Paged Context FMHA is also supported with the FP8 quantization workflow.
|
||
You need to specify <code class="docutils literal notranslate"><span class="pre">use_fp8_context_fmha</span> <span class="pre">=</span> <span class="pre">enable</span></code> and
|
||
<code class="docutils literal notranslate"><span class="pre">use_paged_context_fmha</span> <span class="pre">=</span> <span class="pre">enable</span></code> at the same time.</p>
|
||
<p>Please be aware that this is an experimental feature only supported on Hopper.
|
||
If you notice a significant decrease in accuracy, it is recommended to disable
|
||
it.</p>
|
||
</section>
|
||
</section>
|
||
<section id="generation-phase">
|
||
<h3>Generation Phase<a class="headerlink" href="#generation-phase" title="Link to this heading"></a></h3>
|
||
<p>The generation phase is implemented using a single kernel called the masked
|
||
multi-head attention in TensorRT-LLM. That kernel is able to apply
|
||
pre-processing on the Q, K, and V elements on-the-fly: adds the QKV bias, applies
|
||
RoPE, and performs dequantization and quantization. TensorRT-LLM will continue to add (or
|
||
enable) additional features in future releases, for example, support
|
||
for IA3.</p>
|
||
<p><em>The masked MHA kernel has a special version that distributes the work across
|
||
multiple CUDA thread-blocks on the GPU for cases where the GPU occupancy is
|
||
low. That mode, called multi-block, can be enabled using the <code class="docutils literal notranslate"><span class="pre">multi_block_mode</span></code>
|
||
flag. Users are recommended to test that mode in scenarios where both the batch
|
||
size and the number of heads in the model are relatively small. The exact
|
||
definition of small in that context will depend on the model of the GPU and is
|
||
hard to predict, but as a rule of thumb, it is worth testing that
|
||
mode when <code class="docutils literal notranslate"><span class="pre">batch_size</span> <span class="pre">*</span> <span class="pre">num_heads</span></code> is less than the number of multi-processors
|
||
on the GPU (that suggestion may evolve in the future as more research is
|
||
conducted and the software improves)</em>.</p>
|
||
<p><em>Note that even if the multi-block mode is enabled, the attention operator will
|
||
not immediately trigger the multi-block version of the GPU kernel. There is a
|
||
minimum number of tokens (input + generated) that are required for the
|
||
multi-block version to become more efficient than the “vanilla” implementation
|
||
that uses a single CUDA thread-block per head. It is controlled by an internal
|
||
heuristic.</em></p>
|
||
<p>Note also that, because the masked MHA kernels use an amount of shared memory
|
||
proportional to the sequence length, there can be cases where the GPU’s shared
|
||
memory is insufficient when multi-block mode is not enabled. To make the masked MHA
|
||
kernel work in these cases, multi-block mode is forced on and a warning is
|
||
logged.</p>
|
||
<section id="xqa-optimization">
|
||
<h4>XQA Optimization<a class="headerlink" href="#xqa-optimization" title="Link to this heading"></a></h4>
|
||
<p>Another optimization for MQA/GQA in the generation phase is called XQA optimization.
|
||
It is still an experimental feature and supports a limited set of configurations. LLaMA 2 70B
|
||
is one model that it supports.</p>
|
||
<p>Support matrix of the XQA optimization:</p>
|
||
<ul class="simple">
|
||
<li><p>FP16 / BF16 compute data type.</p></li>
|
||
<li><p>FP16 / BF16 / FP8 / INT8 KV cache data type.</p></li>
|
||
<li><p>Paged KV cache (64 / 128 tokens per block).</p></li>
|
||
</ul>
|
||
<p>XQA is enabled by default. To disable it, use the
|
||
flag <code class="docutils literal notranslate"><span class="pre">--disable_xqa</span></code> when building the engines. Note that a heuristic algorithm
|
||
is also used to decide whether to use the XQA kernel or the masked MHA kernel to get
|
||
better performance. That means that even if <code class="docutils literal notranslate"><span class="pre">--disable_xqa</span></code> is not set, XQA kernels
|
||
may still not be used. If you want to always use that kernel when possible, set
|
||
<code class="docutils literal notranslate"><span class="pre">TRTLLM_FORCE_XQA=1</span></code> to force the use of XQA kernels whenever the model config
|
||
is supported. The detailed supported configurations can be found in the function <code class="docutils literal notranslate"><span class="pre">shouldUse</span></code>
|
||
of class <code class="docutils literal notranslate"><span class="pre">DecoderXQARunner</span></code> in
|
||
<code class="docutils literal notranslate"><span class="pre">cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h</span></code>.</p>
|
||
</section>
|
||
</section>
|
||
</section>
|
||
<section id="in-flight-batching">
|
||
<span id="inflight-batching"></span><h2>In-flight Batching<a class="headerlink" href="#in-flight-batching" title="Link to this heading"></a></h2>
|
||
<p>TensorRT-LLM supports in-flight batching of requests (also known as continuous
|
||
batching or iteration-level batching) for higher serving throughput. With this feature,
|
||
sequences in context phase can be processed together with sequences in
|
||
generation phase. The purpose of that technique is to better interleave
|
||
requests to reduce latency as well as make better use of the GPUs.
|
||
For efficiency reasons (1), the support for in-flight batching <em><strong>requires the
|
||
input tensors to be packed (no padding)</strong></em>.</p>
|
||
<p><em><strong>In the current implementation, the sequences that are going through the
|
||
context phase must be before the sequences in the generation phase in the input
|
||
tensor. For example, for sequences <code class="docutils literal notranslate"><span class="pre">S0</span></code>, <code class="docutils literal notranslate"><span class="pre">S1</span></code> and <code class="docutils literal notranslate"><span class="pre">S2</span></code>, if <code class="docutils literal notranslate"><span class="pre">S0</span></code> and <code class="docutils literal notranslate"><span class="pre">S2</span></code> are in
|
||
context phase (and <code class="docutils literal notranslate"><span class="pre">S1</span></code> in generation), tokens from <code class="docutils literal notranslate"><span class="pre">S0</span></code> and <code class="docutils literal notranslate"><span class="pre">S2</span></code> must appear
|
||
before the tokens of <code class="docutils literal notranslate"><span class="pre">S1</span></code> in the input tensor</strong></em>. The constraint may or may not
|
||
be relaxed in a future version.</p>
|
||
<p><em>(1) Padding sequences in the generation phase, that contain a single token, to
|
||
the length of the maximum input sequence is an inefficient use of resources</em>.</p>
|
||
</section>
|
||
<section id="chunked-context">
|
||
<h2>Chunked Context<a class="headerlink" href="#chunked-context" title="Link to this heading"></a></h2>
|
||
<p>Without chunked context, all context tokens are processed at
|
||
once. This feature splits the context into several chunks. In this way, the
|
||
context chunks can be batched with more tokens during the generation phase,
|
||
which is expected to increase the total throughput. Chunking contexts also removes
|
||
constraints on input length. To enable this feature, the FMHA paged kv-cache also
|
||
needs to be enabled. Except for the last one, the size of the context chunk needs
|
||
to be an integer multiple of the kv-cache block size. Refer to
|
||
<a class="reference internal" href="../performance/perf-best-practices.html#chunked-context"><span class="std std-ref">the performance best practices</span></a> for usage.</p>
|
||
</section>
|
||
<section id="kv-cache">
|
||
<h2>KV Cache<a class="headerlink" href="#kv-cache" title="Link to this heading"></a></h2>
|
||
<p>In the generation phase, a common optimization is to provide the MHA kernel
|
||
with a cache containing the values of the past K and V elements that have
|
||
already been computed. That cache is known as the KV cache. TensorRT-LLM uses
|
||
that technique to accelerate its generation phase. In TensorRT-LLM, there is
|
||
one KV cache per Transformer layer, which means that there are as many KV
|
||
caches as layers in a model. The current version of TensorRT-LLM supports two
|
||
different types of KV caches: <strong>contiguous</strong> and <strong>paged</strong> KV caches.</p>
|
||
<section id="contiguous-kv-cache">
|
||
<h3>Contiguous KV Cache<a class="headerlink" href="#contiguous-kv-cache" title="Link to this heading"></a></h3>
|
||
<p>The contiguous KV cache is a monolithic tensor. Its shape is:</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="p">[</span><span class="n">max_batch_size</span> <span class="o">*</span> <span class="n">max_beam_width</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">max_seqlen</span><span class="p">,</span> <span class="n">hidden_dim_per_head</span><span class="p">]</span><span class="o">.</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>That implementation uses a lot more memory than needed when the sequences are
|
||
shorter than the maximum sequence length (even if they end up close to the
|
||
limit after the generation of many output tokens, it may take a lot of steps to
|
||
reach that point).</p>
|
||
</section>
|
||
<section id="paged-kv-cache">
|
||
<h3>Paged KV Cache<a class="headerlink" href="#paged-kv-cache" title="Link to this heading"></a></h3>
|
||
<p>The paged KV cache decomposes the KV cache into blocks that are distributed to
|
||
the different requests by a cache manager during processing. That cache manager
|
||
keeps track of the sequences, allocates new blocks from a pool and recycles those
|
||
blocks when required. See the simplified implementation of
|
||
<a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/tree/rel/tensorrt_llm/runtime/kv_cache_manager.py"><code class="docutils literal notranslate"><span class="pre">tensorrt_llm.runtime.KVCacheManager</span></code></a>.
|
||
A more efficient C++ implementation is included in the
|
||
<a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/include/tensorrt_llm/batch_manager">Batch Manager</a>.</p>
|
||
</section>
|
||
</section>
|
||
<section id="int8-fp8-kv-caches">
|
||
<h2>INT8/FP8 KV Caches<a class="headerlink" href="#int8-fp8-kv-caches" title="Link to this heading"></a></h2>
|
||
<p>In its current implementation, even if the rest of the network runs in INT8 or
|
||
FP8, the GPT attention operator works with FP32, FP16, and BFloat16 inputs and
|
||
outputs. However, TensorRT-LLM supports INT8 and FP8
|
||
(<code class="docutils literal notranslate"><span class="pre">kv_cache_quant_mode=QuantMode.INT8_KV_CACHE</span></code> and
|
||
<code class="docutils literal notranslate"><span class="pre">kv_cache_quant_mode=QuantMode.FP8_KV_CACHE</span></code>) KV caches.</p>
|
||
<p>The GPT attention operator populates the KV cache. When INT8 or FP8 KV caches
|
||
are enabled, the input values have to be quantized to 8 bits using a scaling
|
||
factor. For quantization, the scaling factor is stored in the
|
||
<code class="docutils literal notranslate"><span class="pre">kv_cache_scaling_factor</span></code> tensor. Its shape is <code class="docutils literal notranslate"><span class="pre">[1]</span></code> and only per-tensor
|
||
quantization is supported in the current version. Quantization uses the inverse of the scale,
|
||
since the plugin computes <code class="docutils literal notranslate"><span class="pre">fp_value</span> <span class="pre">*</span> <span class="pre">(1.0</span> <span class="pre">/</span> <span class="pre">kv_cache_scaling_factor)</span></code>.</p>
|
||
<p>During generation, the values read from the cache are dequantized on-the-fly in
|
||
the MHA/MQA kernel; dequantization can be described as
|
||
<code class="docutils literal notranslate"><span class="pre">quantized_value</span> <span class="pre">*</span> <span class="pre">kv_cache_scaling_factor</span></code>.</p>
|
||
</section>
|
||
<section id="sliding-window-attention-cyclic-rolling-buffer-kv-cache">
|
||
<h2>Sliding Window Attention, Cyclic (Rolling Buffer) KV Cache<a class="headerlink" href="#sliding-window-attention-cyclic-rolling-buffer-kv-cache" title="Link to this heading"></a></h2>
|
||
<p>TensorRT-LLM has a feature called <code class="docutils literal notranslate"><span class="pre">Cyclic</span> <span class="pre">KV</span> <span class="pre">Cache</span></code>, which treats the kv cache
|
||
as a circular buffer. This means that it only stores the kv cache for the last N
|
||
tokens, where N is determined by the <code class="docutils literal notranslate"><span class="pre">max_attention_window_size</span></code> parameter in
|
||
<code class="docutils literal notranslate"><span class="pre">GenerationSession.setup</span></code>. You can see examples of this in the <code class="docutils literal notranslate"><span class="pre">run.py</span></code> or
|
||
<code class="docutils literal notranslate"><span class="pre">summarize.py</span></code> files. When the cache is full, new tokens’ kv cache will
|
||
overwrite the “least recently used” caches.</p>
|
||
<p>In the context phase, if the input length surpasses the <code class="docutils literal notranslate"><span class="pre">max_attention_window_size</span></code>,
|
||
<code class="docutils literal notranslate"><span class="pre">Sliding</span> <span class="pre">Window</span> <span class="pre">Attention</span></code> will be activated. This serves the same function as
|
||
the window size in sliding window attention: each token attends only to the most recent <code class="docutils literal notranslate"><span class="pre">max_attention_window_size</span></code> tokens.</p>
|
||
<p>This feature helps to reduce the memory footprint of the KV cache when
|
||
dealing with very long sequences.</p>
|
||
<p>Note that the cyclic KV cache feature currently doesn’t work with beam search, as
|
||
the context KV cache is shared across beams.</p>
|
||
<p><em>An experimental feature that allows different <code class="docutils literal notranslate"><span class="pre">max_attention_window_size</span></code> values
|
||
for each layer, is also supported. To utilize this feature, simply provide an
|
||
<code class="docutils literal notranslate"><span class="pre">int32</span> <span class="pre">torch.Tensor</span></code> with a shape of <code class="docutils literal notranslate"><span class="pre">[num_layers]</span></code> to <code class="docutils literal notranslate"><span class="pre">GenerationSession.setup</span></code>.
|
||
This tensor will serve as the buffer for <code class="docutils literal notranslate"><span class="pre">max_attention_window_size</span></code>,
|
||
setting unique values for each layer. However, it’s important to note that the
|
||
memory allocation for the KV cache is still based on the maximum value in that tensor.</em></p>
|
||
</section>
|
||
<section id="streamingllm">
|
||
<h2>StreamingLLM<a class="headerlink" href="#streamingllm" title="Link to this heading"></a></h2>
|
||
<p>The StreamingLLM feature uses window attention to perform efficient and stable LLM inference
|
||
on long texts, which means that only <code class="docutils literal notranslate"><span class="pre">N</span></code> tokens need to be stored in the KV cache.
|
||
Similar to the cyclic KV cache feature in TensorRT-LLM, the <code class="docutils literal notranslate"><span class="pre">max_attention_window_size</span></code>
|
||
parameter is used to determine <code class="docutils literal notranslate"><span class="pre">N</span></code>. Unlike the cyclic KV cache feature,
|
||
the first <code class="docutils literal notranslate"><span class="pre">S</span></code> tokens, called sink tokens, are always kept in the attention window,
|
||
where <code class="docutils literal notranslate"><span class="pre">S</span></code> is determined by the <code class="docutils literal notranslate"><span class="pre">sink_token_length</span></code> parameter in <code class="docutils literal notranslate"><span class="pre">GenerationSession.setup</span></code>.
|
||
However, in the context phase, the self-attention is dense in the official implementation of
|
||
StreamingLLM, and it uses all of the tokens for computation and only saves <code class="docutils literal notranslate"><span class="pre">N</span></code> tokens
|
||
to the KV cache.</p>
|
||
<p>In addition, the relative position embedding is also changed in StreamingLLM.
|
||
When determining the relative distance and adding positional information to tokens,
|
||
StreamingLLM uses the positions within the cache rather than those in the original text.</p>
|
||
<p>The <code class="docutils literal notranslate"><span class="pre">streamingllm</span></code> flag is used to enable this feature.</p>
|
||
</section>
|
||
<section id="beam-search">
|
||
<h2>Beam-Search<a class="headerlink" href="#beam-search" title="Link to this heading"></a></h2>
|
||
<p>The GPT attention operator supports beam-search. In the context phase, a single
|
||
beam is computed per input sequence. In the generation phase, the MHA/MQA/GQA
|
||
kernel uses an additional tensor to reconstruct the correct path for each beam.
|
||
That tensor is called the <code class="docutils literal notranslate"><span class="pre">cache_indirection</span></code>. Its shape is <code class="docutils literal notranslate"><span class="pre">[batch_size,</span> <span class="pre">beam_width,</span> <span class="pre">max_seqlen]</span></code>.</p>
|
||
<p>For a sequence <code class="docutils literal notranslate"><span class="pre">si</span></code>, a beam <code class="docutils literal notranslate"><span class="pre">bi</span></code> and a token <code class="docutils literal notranslate"><span class="pre">ti</span></code>, the element
|
||
<code class="docutils literal notranslate"><span class="pre">cache_indirection[si][bi][ti]</span></code> is an integer between <code class="docutils literal notranslate"><span class="pre">0</span></code> and <code class="docutils literal notranslate"><span class="pre">beam_width-1</span></code>
|
||
that indicates which path in the beam to read the K and V elements from in the
|
||
KV cache. This tensor is populated in the sampling stage.</p>
|
||
</section>
|
||
<section id="input-qkv-tensor">
|
||
<h2>Input QKV tensor<a class="headerlink" href="#input-qkv-tensor" title="Link to this heading"></a></h2>
|
||
<p>The input QKV tensor packs the Q, K and V tensors (concatenated along the last
|
||
dimension) after the projection of the hidden states. It is a 3D tensor. RoPE
|
||
and quantization to INT8 or FP8 (when needed) are performed by the GPT
|
||
attention operator.</p>
|
||
<p>In padded mode, its shape is <code class="docutils literal notranslate"><span class="pre">[batch_beam_size,</span> <span class="pre">max_seqlen,</span> <span class="pre">3</span> <span class="pre">*</span> <span class="pre">hidden_dim]</span></code>
|
||
where <code class="docutils literal notranslate"><span class="pre">batch_beam_size</span></code> is the batch size (number of sequences) for the context
|
||
phase and the batch size multiplied by the beam width for the generation phase.
|
||
Having different beam widths per sequence in padded mode is not supported.</p>
|
||
<p>In packed mode, its shape is <code class="docutils literal notranslate"><span class="pre">[num_tokens,</span> <span class="pre">3</span> <span class="pre">*</span> <span class="pre">hidden_dim]</span></code> where
|
||
<code class="docutils literal notranslate"><span class="pre">num_tokens</span></code> is the total number of tokens in the batch. For the sequences in
|
||
context phase, the number of tokens of a sequence corresponds to its input
|
||
length (even if the beam width is greater than <code class="docutils literal notranslate"><span class="pre">1</span></code> for beam search). For the
|
||
sequences in generation phase, there are <code class="docutils literal notranslate"><span class="pre">beam_width</span></code> tokens per sequence. The
|
||
beam width can be different for each sequence.</p>
|
||
<p>In other words, the pseudo-code to compute the number of tokens is:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">num_tokens</span> <span class="o">=</span> <span class="mi">0</span>
|
||
|
||
<span class="c1"># Add the length of each sequence in context phase.</span>
|
||
<span class="k">for</span> <span class="n">seq</span> <span class="ow">in</span> <span class="n">context_phase</span><span class="p">:</span>
|
||
    <span class="n">num_tokens</span> <span class="o">+=</span> <span class="n">seq</span><span class="o">.</span><span class="n">length</span>
|
||
|
||
<span class="c1"># Add the width of the beam for each sequence in generation phase.</span>
|
||
<span class="k">for</span> <span class="n">seq</span> <span class="ow">in</span> <span class="n">generation_phase</span><span class="p">:</span>
|
||
    <span class="n">num_tokens</span> <span class="o">+=</span> <span class="n">seq</span><span class="o">.</span><span class="n">beam_width</span>
|
||
</pre></div>
|
||
</div>
|
||
<section id="rotary-positional-embedding-rope">
|
||
<h3>Rotary Positional Embedding (RoPE)<a class="headerlink" href="#rotary-positional-embedding-rope" title="Link to this heading"></a></h3>
|
||
<p>The GPT attention operation can perform the computation of the Rotary
|
||
Positional Embedding (RoPE). When that operation is enabled, that is, when
|
||
<code class="docutils literal notranslate"><span class="pre">rotary_embedding_dim</span></code> is set to a value greater than 0, it is fused with other
|
||
operations. The GPT operator supports GPT-NeoX and GPT-J forms of RoPE by
|
||
setting <code class="docutils literal notranslate"><span class="pre">position_embedding_type</span></code> to <code class="docutils literal notranslate"><span class="pre">PositionEmbeddingType.rope_gpt_neox</span></code>
|
||
or <code class="docutils literal notranslate"><span class="pre">PositionEmbeddingType.rope_gptj</span></code>.</p>
|
||
</section>
|
||
<section id="alibi">
|
||
<h3>ALiBi<a class="headerlink" href="#alibi" title="Link to this heading"></a></h3>
|
||
<p>The GPT attention operator can apply ALiBi to the result of the <code class="docutils literal notranslate"><span class="pre">Q*K^T</span></code>
|
||
product. The bias is computed on-the-fly from the ALiBi slopes in the optimized
|
||
kernel.</p>
|
||
</section>
|
||
<section id="scaling-factor-s">
|
||
<h3>Scaling factor(s)<a class="headerlink" href="#scaling-factor-s" title="Link to this heading"></a></h3>
|
||
<p>In MHA, the output of the <code class="docutils literal notranslate"><span class="pre">Q*K^T</span></code> product is scaled by a constant value that
|
||
is computed as:</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">norm_factor</span> <span class="o">=</span> <span class="mf">1.</span><span class="n">f</span> <span class="o">/</span> <span class="p">(</span><span class="n">q_scaling</span> <span class="o">*</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">head_size</span><span class="p">))</span><span class="o">.</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="cross-attention">
|
||
<h3>Cross Attention<a class="headerlink" href="#cross-attention" title="Link to this heading"></a></h3>
|
||
<p>In addition to the self-attention MHA needed by GPT-style decoder-only models, <code class="docutils literal notranslate"><span class="pre">gpt_attention</span></code> also supports cross attention.</p>
|
||
<p>This enables using <code class="docutils literal notranslate"><span class="pre">gpt_attention</span></code> more broadly as a generic decoder component. For example, the Encoder-Decoder model uses <code class="docutils literal notranslate"><span class="pre">gpt_attention</span></code> to implement both the self-attention and cross-attention modules in its decoder.</p>
|
||
</section>
|
||
<section id="relative-attention-bias-rab">
|
||
<h3>Relative Attention Bias (RAB)<a class="headerlink" href="#relative-attention-bias-rab" title="Link to this heading"></a></h3>
|
||
<p>Relative attention bias (RAB) is a kind of relative position modeling, adding an attention bias (<code class="docutils literal notranslate"><span class="pre">Q*K^T+bias</span></code>) according to relative positions. RAB is a lightweight method to include the information of relative positions, and is used in the popular Encoder-Decoder model <a class="reference external" href="https://huggingface.co/docs/transformers/model_doc/t5">T5</a> and also other models in the T5 family.</p>
|
||
<p>RAB is supported in two modes: i) regular mode, in which the user passes in a relative attention bias computed ahead of MHA; ii) implicit mode, which computes the relative attention bias on the fly in MHA. The implicit mode suits the case when the relative attention bias is too large to fit in memory, and it can be turned on by passing in <code class="docutils literal notranslate"><span class="pre">max_distance</span></code>.</p>
|
||
</section>
|
||
</section>
|
||
</section>
|
||
|
||
|
||
</div>
|
||
</div>
|
||
<footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
|
||
<a href="../architecture/add-model.html" class="btn btn-neutral float-left" title="Adding a Model" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||
<a href="gpt-runtime.html" class="btn btn-neutral float-right" title="C++ GPT Runtime" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||
</div>
|
||
|
||
<hr/>
|
||
|
||
<div role="contentinfo">
|
||
|
||
|
||
<div class="footer">
|
||
<p>
|
||
Copyright © 2024 NVIDIA Corporation
|
||
</p>
|
||
<p>
|
||
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/" target="_blank" rel="noopener"
|
||
data-cms-ai="0">Privacy Policy</a> |
|
||
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/" target="_blank" rel="noopener"
|
||
data-cms-ai="0">Manage My Privacy</a> |
|
||
<a class="Link" href="https://www.nvidia.com/en-us/preferences/start/" target="_blank" rel="noopener"
|
||
data-cms-ai="0">Do Not Sell or Share My Data</a> |
|
||
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/" target="_blank"
|
||
rel="noopener" data-cms-ai="0">Terms of Service</a> |
|
||
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/accessibility/" target="_blank" rel="noopener"
|
||
data-cms-ai="0">Accessibility</a> |
|
||
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/company-policies/" target="_blank"
|
||
rel="noopener" data-cms-ai="0">Corporate Policies</a> |
|
||
<a class="Link" href="https://www.nvidia.com/en-us/product-security/" target="_blank" rel="noopener"
|
||
data-cms-ai="0">Product Security</a> |
|
||
<a class="Link" href="https://www.nvidia.com/en-us/contact/" target="_blank" rel="noopener"
|
||
data-cms-ai="0">Contact</a>
|
||
</p>
|
||
</div>
|
||
|
||
|
||
</div>
|
||
|
||
|
||
|
||
</footer>
|
||
</div>
|
||
</div>
|
||
</section>
|
||
</div>
|
||
<script>
|
||
jQuery(function () {
|
||
SphinxRtdTheme.Navigation.enable(true);
|
||
});
|
||
</script>
|
||
|
||
</body>
|
||
</html> |