<!DOCTYPE html>
<html class="writer-html5" lang="en" data-content_root="./">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Speculative Sampling &mdash; tensorrt_llm documentation</title>
</head>
<body>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<section id="speculative-sampling">
<h1>Speculative Sampling<a class="headerlink" href="#speculative-sampling" title="Link to this heading"></a></h1>
<p>Speculative Sampling (also referred to as Speculative Decoding) is a set of techniques designed to allow generation of more than one token per forward pass iteration. This can lead to a reduction in the average per-token latency <strong>in situations where the GPU
is underutilized due to small batch sizes.</strong></p>
<p>Speculative Sampling involves predicting a sequence of future tokens, referred to as draft tokens, using a method
that is substantially more efficient than repeatedly executing the target Large Language Model (LLM).
These draft tokens are then collectively validated by processing them through the target LLM in a single forward pass.
The underlying assumptions are twofold:</p>
<ol class="arabic simple">
<li><p>processing multiple draft tokens concurrently will be as rapid as processing a single token</p></li>
<li><p>multiple draft tokens will be validated successfully over the course of the full generation</p></li>
</ol>
<p>If the first assumption holds true, the latency of speculative decoding will be no worse than that of the standard approach. If the second also holds, output token generation advances by statistically more than one token per forward pass.
Together, these two effects allow speculative decoding to reduce latency.</p>
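<p>To make the acceptance rule concrete, the following is a minimal, framework-agnostic Python sketch of a single draft-then-verify iteration under greedy (temperature 0) validation. The functions <code>draft_next_token</code> and <code>target_greedy_tokens</code> are hypothetical stand-ins for the Draft and Target model calls, not TensorRT-LLM APIs.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
# Minimal sketch of one speculative-decoding iteration with greedy (temperature 0)
# validation. `draft_next_token` and `target_greedy_tokens` are hypothetical
# stand-ins for the Draft and Target model calls.

def speculative_step(prompt, k, draft_next_token, target_greedy_tokens):
    # 1. The Draft model proposes up to K tokens, one at a time.
    draft_tokens = []
    context = list(prompt)
    for _ in range(k):
        token = draft_next_token(context)
        draft_tokens.append(token)
        context.append(token)

    # 2. The Target model processes the prompt plus all K draft tokens in a
    #    single forward pass and returns its own greedy choice at each of the
    #    K + 1 new positions.
    target_tokens = target_greedy_tokens(prompt, draft_tokens)

    # 3. Accept the longest prefix of draft tokens the Target model agrees
    #    with, plus one token from the Target model itself, so each iteration
    #    advances generation by 1 to K + 1 tokens.
    accepted = []
    for i, token in enumerate(draft_tokens):
        if token != target_tokens[i]:
            break
        accepted.append(token)
    accepted.append(target_tokens[len(accepted)])
    return accepted
</pre></div></div>
<p>When the Draft model often agrees with the Target model, <code>accepted</code> usually contains several tokens, which is what amortizes the cost of the single Target forward pass.</p>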
<p>TensorRT-LLM supports several approaches for generating draft tokens, including:</p>
<ol class="arabic simple">
<li><p>Utilizing a smaller, auxiliary model, known as the draft model approach. For more information, refer to the <a class="reference external" href="https://arxiv.org/pdf/2211.17192.pdf">Fast Inference from Transformers via Speculative Decoding paper</a>.</p></li>
<li><p>Implementing additional language model heads that predict tokens for future positions, as detailed in the <a class="reference external" href="https://arxiv.org/abs/2401.10774">Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads paper</a>.</p></li>
</ol>
<section id="performance-improvements">
<h2>Performance Improvements<a class="headerlink" href="#performance-improvements" title="Link to this heading"></a></h2>
<p>It's important to note that the effectiveness of speculative decoding techniques is highly dependent
on the specific task at hand. For instance, forecasting subsequent tokens in a code-completion scenario
may prove simpler than generating a summary for an article.</p>
<p>Furthermore, when integrating Medusa with a standard PyTorch model implementation, which may not be as finely
tuned as TensorRT-LLM, the potential time savings are more pronounced.</p>
</section>
</section>
<section id="draft-model-approach">
<h1>Draft Model Approach<a class="headerlink" href="#draft-model-approach" title="Link to this heading"></a></h1>
<p>The Draft model approach involves the use of two distinct models trained independently
but sharing the same vocabulary: a smaller Draft model and a larger Target model.
For example, a GPT 125M model can serve as the Draft model, while a GPT 6.7B model acts as the Target model.</p>
<p>The management of Draft and Target models is facilitated through two separate <code class="docutils literal notranslate"><span class="pre">GptManager</span></code> instances.
It is essential that you coordinate the interactions between the Draft and Target models effectively.
Initially, the Draft model is queried to generate up to <code class="docutils literal notranslate"><span class="pre">K</span></code> draft tokens.
These tokens are then forwarded to the Target model for verification.
Upon verification, the Target model may return up to <code class="docutils literal notranslate"><span class="pre">K+1</span></code> tokens.
Subsequently, the prompt, now updated with the accepted tokens, is sent back to the Draft model to initiate the generation of new draft tokens.
This iterative process continues until a predefined stop condition is met.
An example of this orchestration process can be found in the <a class="reference external" href="https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/inflight_batcher_llm/client/e2e_grpc_speculative_decoding_client.py">TensorRT-LLM Triton backend</a>.</p>
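<p>The orchestration itself reduces to a simple loop. The sketch below assumes two hypothetical wrappers, <code>query_draft_model</code> and <code>query_target_model</code>, around the two <code>GptManager</code> instances (or two Triton model endpoints); refer to the linked client for a complete implementation.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
# Hedged sketch of the Draft/Target orchestration loop described above.
# `query_draft_model` and `query_target_model` are hypothetical wrappers around
# the two GptManager instances; they are not TensorRT-LLM API calls.

def run_speculative(prompt, k, end_id, max_total_tokens,
                    query_draft_model, query_target_model):
    output = list(prompt)
    while len(output) - len(prompt) &lt; max_total_tokens:
        # Draft model request: generate up to K draft tokens (maxNewTokens = K).
        draft_tokens = query_draft_model(output, max_new_tokens=k)

        # Target model request: verify the draft in a single pass
        # (maxNewTokens = 1, draft tokens set in the draftTokens field).
        accepted = query_target_model(output, draft_tokens=draft_tokens)

        # The Target model returns between 1 and K + 1 tokens; append them and
        # repeat until a stop condition (end token or length limit) is met.
        output.extend(accepted)
        if end_id in accepted:
            break
    return output
</pre></div></div>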
<p>Configuring and executing the Draft model within the Inflight Fused Batching (IFB) framework
follows the same procedure as for any other model within IFB.
The <code class="docutils literal notranslate"><span class="pre">maxNewTokens</span></code> parameter should be set to the number of draft tokens in the <code class="docutils literal notranslate"><span class="pre">LlmRequest</span></code> for the Draft model query.</p>
<p>When building the Target model, it is necessary to specify the <code class="docutils literal notranslate"><span class="pre">--max_draft_len</span> <span class="pre">&lt;K&gt;</span> <span class="pre">--speculative_decoding_mode</span> <span class="pre">draft_tokens_external</span></code> option to the <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code> command.
During the Target model's inference phase in IFB, <code class="docutils literal notranslate"><span class="pre">maxNewTokens</span></code> should be set to <code class="docutils literal notranslate"><span class="pre">1</span></code>,
and the draft tokens must be set in the <code class="docutils literal notranslate"><span class="pre">draftTokens</span></code> field of the <code class="docutils literal notranslate"><span class="pre">LlmRequest</span></code> for the Target model query.</p>
<p><strong>NOTE:</strong> Because the Draft and Target models are repeatedly queried with requests that share a common prefix,
it is advisable to enable KV cache reuse for both models to improve performance.
This can be achieved by adding the <code class="docutils literal notranslate"><span class="pre">--use_paged_context_fmha=enable</span></code> flag to the <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code> command
and setting <code class="docutils literal notranslate"><span class="pre">enableBlockReuse=true</span></code> in the <code class="docutils literal notranslate"><span class="pre">KVCacheConfig</span></code>.</p>
</section>
<section id="medusa">
<h1>Medusa<a class="headerlink" href="#medusa" title="Link to this heading"></a></h1>
<p>This approach leverages a single model to both generate and verify draft tokens.
It enhances the existing model by adding multiple extra language model heads, known as Medusa heads.
These additional heads are trained to predict future tokens while the base model remains unchanged.
Specifically, the first Medusa head is tasked with predicting the immediate next token,
the second head predicts the token after that, and so on.
With <code class="docutils literal notranslate"><span class="pre">K</span></code> Medusa heads, the model can forecast up to <code class="docutils literal notranslate"><span class="pre">K</span></code> tokens ahead.
The draft tokens generated by the Medusa heads during iteration <code class="docutils literal notranslate"><span class="pre">i</span></code>
are then verified and potentially accepted in the subsequent iteration, <code class="docutils literal notranslate"><span class="pre">i+1</span></code>.</p>
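<p>As a rough illustration of the mechanism (not of the TensorRT-LLM implementation), the sketch below reduces each Medusa head to a single linear projection over the last hidden state; real Medusa heads are small trained networks, and all names and sizes here are illustrative.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
import numpy as np

# Schematic sketch: K Medusa heads propose K draft tokens from the same
# last-layer hidden state that feeds the base LM head. Each head is reduced to
# a single (random, untrained) linear projection purely for illustration.

hidden_size, vocab_size, K = 64, 1000, 4        # tiny sizes for illustration
rng = np.random.default_rng(0)

lm_head = rng.standard_normal((hidden_size, vocab_size))     # base model LM head
medusa_heads = [rng.standard_normal((hidden_size, vocab_size)) for _ in range(K)]

hidden_state = rng.standard_normal(hidden_size)              # last-layer output

next_token = int(np.argmax(hidden_state @ lm_head))          # the true token l0
draft_tokens = [int(np.argmax(hidden_state @ head)) for head in medusa_heads]
# draft_tokens[0] is MH1's guess for the token right after next_token,
# draft_tokens[1] is MH2's guess for the token after that, and so on.
# With a TopK of more than one, each head would instead keep its k
# highest-scoring token indices rather than a single argmax.
</pre></div></div>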
<p>The true potential of the Medusa strategy is realized when more than one token per head is used,
employing a TopK approach to create multiple potential paths, essentially forming a tree, rather than
a single linear path as seen in the Draft model approach. To reduce redundant computations, many of these paths,
which often share common prefixes, are consolidated into a single path.
This is achieved by applying attention with a sparse mask that represents the various paths. The sparse mask formed by the Medusa tree is described in detail later.</p>
<p>By validating multiple paths simultaneously, there is an increased likelihood of accepting more than one token per iteration,
albeit at the expense of additional computational effort.</p>
<p>It is crucial to recognize that as the number of potential paths grows exponentially with <code class="docutils literal notranslate"><span class="pre">K</span></code>,
it is not necessary to explore or validate all of them. A recommended strategy for managing this complexity is to prune the tree
by focusing only on the paths with higher-probability tokens.</p>
<p>You must strike a balance between the breadth and depth of the tree you want to explore and the impact of a larger tree on the overall
performance for your specific application.</p>
<p>In the TensorRT-LLM implementation of Medusa, the configuration of the tree is a runtime parameter.
This flexibility allows you to experiment and identify the optimal tree structure for your use case,
which can then be utilized in a production environment.</p>
<section id="medusa-tree">
<h2>Medusa Tree<a class="headerlink" href="#medusa-tree" title="Link to this heading"></a></h2>
<p>Consider the following diagram, which illustrates how the hidden states from the last layer of the base model
are passed to the base model's language model (LM) head and to four Medusa heads (MHs).</p>
<p align="center">
<img src="./media/medusa_tree.svg" alt="Example Medusa Tree" width="auto" height="auto">
</p>
<p>In this example:</p>
<ol class="arabic simple">
<li><p>The token <code>l<sub>0</sub></code> represents the actual token generated by the model.
All other tokens, denoted as <code>p<sub>hk</sub></code>, are predictions from the MHs,
where <code class="docutils literal notranslate"><span class="pre">h</span></code> indicates the Medusa head index (1-based) and <code class="docutils literal notranslate"><span class="pre">k</span></code> represents the TopK choice index (0-based).</p></li>
<li><p>Four MHs are used, which means the model is predicting four future tokens.</p></li>
<li><p>The first two MHs utilize Top-2 predictions, while the last two use Top-1.
For instance, <code>p<sub>10</sub></code> and <code>p<sub>11</sub></code> are the top and
second top predictions from the first Medusa Head (MH1).</p></li>
<li><p>A total of four paths are explored, which is fewer than the 16 that would be examined
if a complete binary tree were used (assuming Top-2 predictions for all MHs).</p></li>
<li><p>Since a path may also be accepted only partially, there are ten potential candidates (including the case where only the true token is accepted), referred to as <code class="docutils literal notranslate"><span class="pre">medusa_choices</span></code>.
The number of tokens that can be accepted at each step, including the true token,
ranges from 1 (if all Medusa predictions are incorrect) to 5 (if all are correct).</p></li>
</ol>
<p>During the generation phase, the model receives an input of 10 tokens,
which corresponds to the last tokens of each candidate path, rather than just one.</p>
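<p>The following is a small Python sketch, assuming the four paths of this example, of how the consolidated tree and the sparse attention mask mentioned earlier could be derived: each of the 10 tokens attends only to itself and its ancestors in the tree.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
# Sketch: derive the consolidated tree and a sparse attention mask for the
# example above. Node 0 is the true token; the other nodes are the unique
# prefixes of the four paths. A token may attend only to itself and its
# ancestors, which yields the sparse mask used in the single forward pass.

paths = [[0, 0, 0, 0], [0, 1, 0], [1, 0], [1, 1]]

# Unique prefixes of all paths, sorted by depth: the 9 draft-token nodes.
nodes = sorted(
    {tuple(p[:i]) for p in paths for i in range(1, len(p) + 1)},
    key=lambda t: (len(t), t),
)
tree = [()] + nodes                      # prepend the root (the true token)

n = len(tree)                            # 10 tokens fed in one generation step
mask = [[False] * n for _ in range(n)]
for i, node in enumerate(tree):
    for j, other in enumerate(tree):
        # token i attends to token j if `other` is a prefix of `node`
        mask[i][j] = node[:len(other)] == other

print(n)                                 # 10
print(sum(mask[tree.index((0, 1, 0))]))  # the [0,1,0] node attends to 4 tokens
</pre></div></div>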
<p>In TensorRT-LLM, you have the option to define such trees by providing all the Medusa choices
or by simply specifying the unique paths.</p>
<ul class="simple">
<li><p>Since each candidate/path begins with the true token (<code>l<sub>0</sub></code>),
there is no need to specify it separately. For the predicted tokens, only the TopK indices are required.</p></li>
<li><p>For example, to specify the path <code>l<sub>0</sub>p<sub>10</sub>p<sub>21</sub>p<sub>30</sub></code>,
one would use <code class="docutils literal notranslate"><span class="pre">[0,1,0]</span></code>. And
to specify the path <code>l<sub>0</sub>p<sub>11</sub>p<sub>20</sub></code>,
one would use <code class="docutils literal notranslate"><span class="pre">[1,0]</span></code>.</p></li>
<li><p>To specify all 4 paths in the example, use <code class="docutils literal notranslate"><span class="pre">medusa_choices=[[0,0,0,0],</span> <span class="pre">[0,1,0],</span> <span class="pre">[1,0],</span> <span class="pre">[1,1]]</span></code>.</p></li>
<li><p>It's also possible to specify all candidates explicitly, similar to the Medusa repository.
For instance, <code class="docutils literal notranslate"><span class="pre">medusa_choices=[[0],</span> <span class="pre">[0,0],</span> <span class="pre">[0,0,0],</span> <span class="pre">[0,0,0,0],</span> <span class="pre">[0,1],</span> <span class="pre">[0,1,0],</span> <span class="pre">[1],</span> <span class="pre">[1,0],</span> <span class="pre">[1,1]]</span></code>. Note that when specifying all the candidates explicitly, <strong>we don't include
the empty <code class="docutils literal notranslate"><span class="pre">[]</span></code> candidate</strong> for the case where only the true token is accepted, that is, when all the predictions from the MHs are wrong.
So, only <code class="docutils literal notranslate"><span class="pre">9</span></code> candidates are specified.</p></li>
</ul>
<p><strong>Specifying paths-only instead of all choices is currently supported only in the Python runtime.</strong></p>
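<p>As a sanity check, here is a minimal Python sketch (not a TensorRT-LLM utility) showing that the paths-only form above expands to exactly the nine explicit candidates listed: every non-empty prefix of every path is a candidate, and the empty candidate stays implicit.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
# Sketch: expand the paths-only specification into the explicit candidate list.
# Every non-empty prefix of every path is a candidate; the empty candidate
# (only the true token is accepted) is implicit and therefore not listed.

paths = [[0, 0, 0, 0], [0, 1, 0], [1, 0], [1, 1]]

candidates = sorted({tuple(p[:i]) for p in paths for i in range(1, len(p) + 1)})
print([list(c) for c in candidates])
# [[0], [0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 1], [0, 1, 0], [1], [1, 0], [1, 1]]
</pre></div></div>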
</section>
<section id="using-medusa-with-tensorrt-llm">
<h2>Using Medusa with TensorRT-LLM<a class="headerlink" href="#using-medusa-with-tensorrt-llm" title="Link to this heading"></a></h2>
<p>For guidance on constructing and executing Medusa with the Python runtime, consult the <span class="xref myst">Medusa README</span>. When using Inflight Fused Batching (IFB) with the C++ API, it is necessary to define <code class="docutils literal notranslate"><span class="pre">medusa_choices</span></code> explicitly within the model configuration. For detailed instructions, refer to the <a class="reference external" href="https://github.com/triton-inference-server/tensorrtllm_backend?tab=readme-ov-file#modify-the-model-configuration">model configuration in the TensorRT-LLM backend</a>.</p>
<section id="limitations">
<h3>Limitations<a class="headerlink" href="#limitations" title="Link to this heading"></a></h3>
<ul class="simple">
<li><p>TensorRT-LLM supports Medusa only for Vicuna (fine-tuned LLaMA).
However, similar to any new model, you can follow the same approach to define your own Medusa model and deploy it with TensorRT-LLM.</p></li>
<li><p>We match only tokens during the validation phase, that is, <code class="docutils literal notranslate"><span class="pre">medusa_temperature=0</span></code>.</p></li>
<li><p>Beam search is <strong>not</strong> compatible with Medusa.</p></li>
</ul>
</section>
</section>
</section>
</div>
</div>
</body>
</html>