<section id="best-practices-for-tuning-the-performance-of-tensorrt-llm">
|
|
<h1>Best Practices for Tuning the Performance of TensorRT-LLM<a class="headerlink" href="#best-practices-for-tuning-the-performance-of-tensorrt-llm" title="Link to this heading"></a></h1>
|
|
<p>This document provides some best practices for tuning the performance of TensorRT-LLM.</p>

## How To Measure Performance?

TensorRT-LLM can be benchmarked using the included
[C++](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/../../benchmarks/cpp/README.md)
and
[Python](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/../../benchmarks/python/README.md)
tools. However, it is *strongly* recommended to use the C++ benchmarking tool.
For detailed performance data and the steps to reproduce those results, see
this [Document](performance.html).
The [TensorRT-LLM backend](https://github.com/triton-inference-server/tensorrtllm_backend)
can also be used to measure the performance of TensorRT-LLM for online serving.

## Build Options to Optimize the Performance of TensorRT-LLM Models

This part summarizes how to build engines to enhance the performance of the
runtime and, for some options, decrease the engine build time.

***Note that some of those features and how to enable them may change in the future.***

### GPT Attention Plugin and Context Fused Multi-Head Attention

The GPT attention plugin and the fused multi-head attention kernel are enabled by
default. For the context phase, they are controlled by the `--gpt_attention_plugin`
and `--context_fmha` arguments of `trtllm-build`.

The TensorRT-LLM GPT attention plugin uses efficient kernels and enables an
in-place update of the KV cache. It results in reduced memory consumption as
well as the removal of unneeded memory copy operations (compared with an
implementation that uses the `concat` operator to update the KV cache).

Enabling fused multi-head attention during the context phase triggers a kernel
that performs the MHA/MQA/GQA block in a single launch. For more details, see
this [Document](gpt_attention.html#context-phase).

### Remove Input Padding

The remove-input-padding feature is enabled by default; the `--remove_input_padding`
argument of `trtllm-build` controls it.

When input padding is removed, tokens from different sequences are packed together.
That reduces both the amount of computation and the memory consumption. For more
details, see this [Document](gpt_attention.html#padded-and-packed-tensors).
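
To make the packed layout concrete, here is a minimal NumPy sketch, independent of the TensorRT-LLM API (the arrays and token IDs are illustrative only), of padded versus packed inputs:

```python
import numpy as np

# Three requests with different prompt lengths (token IDs are arbitrary).
sequences = [
    np.array([101, 7592, 2088, 102]),        # length 4
    np.array([101, 2129, 2024, 2017, 102]),  # length 5
    np.array([101, 102]),                    # length 2
]

# Padded layout: a [batch_size, max_input_len] tensor, partly filled with padding.
max_input_len = max(len(s) for s in sequences)
padded = np.zeros((len(sequences), max_input_len), dtype=np.int64)
for i, s in enumerate(sequences):
    padded[i, : len(s)] = s

# Packed layout (remove_input_padding): one flat token stream plus per-sequence
# lengths, so no compute or memory is spent on padding tokens.
packed = np.concatenate(sequences)
input_lengths = np.array([len(s) for s in sequences], dtype=np.int32)

print(padded.size, packed.size)  # 15 padded slots vs. 11 real tokens
```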

### Maximum Number of Tokens

It is recommended to tune `--max_num_tokens` for better performance. The
`--max_num_tokens` value can be roughly estimated as:

```
max_batch_size * max_input_len * alpha + max_batch_size * max_beam_width * (1 - alpha)
```

where `alpha` is a floating-point value between `0.0` and `1.0`. It stands for
a rough estimate of the fraction of requests that are in their context phase at each
invocation of the forward function during inference. It is recommended to use a
value between `0.05` and `0.20` (between 5% and 20%), but the best choice depends on the
actual scenario.
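
As a concrete illustration, the following sketch simply evaluates the estimate above for a hypothetical configuration; the numbers are placeholders rather than recommendations:

```python
def estimate_max_num_tokens(max_batch_size: int,
                            max_input_len: int,
                            max_beam_width: int,
                            alpha: float) -> int:
    """Rough --max_num_tokens estimate from the formula above."""
    context_tokens = max_batch_size * max_input_len * alpha
    generation_tokens = max_batch_size * max_beam_width * (1.0 - alpha)
    return int(context_tokens + generation_tokens)

# Hypothetical engine configuration: 64 requests, 2048-token prompts, beam width 1,
# with ~10% of the scheduled requests in the context phase at any given iteration.
print(estimate_max_num_tokens(max_batch_size=64, max_input_len=2048,
                              max_beam_width=1, alpha=0.10))  # -> 13164
```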

The maximum number of tokens will not take effect when input padding is not
removed. When input padding is removed (see [Remove Input
Padding](#remove-input-padding)), the tokens from different sequences are
packed together and the maximum number of tokens can be set to a different
(lower) value, which defaults to `max_input_len * max_batch_size`.
Note that it has to be higher than `max_input_len`.

There are two aspects that must be considered. Firstly, some input sequences
will be shorter than the maximum input length. Secondly, when in-flight
sequence batching is enabled, requests in the context phase are executed together with
requests in the generation phase. Those latter requests produce far fewer tokens
than `max_input_len` (at most, `beam_width` tokens per request).

Using a more realistic value for `max_num_tokens` allows TensorRT-LLM to
allocate more memory to store the KV cache and execute more requests together,
which leads to increased efficiency.

Increasing `max_num_tokens` appropriately is therefore beneficial to performance.
However, GPU utilization plateaus at some point as `--max_num_tokens` grows,
and going beyond that saturation point may hurt both the first-token latency and the
total end-to-end latency.

See also [chunked context](#chunked-context).

### Paged KV Cache

The paged KV cache is enabled by default; the `--paged_kv_cache` argument of
`trtllm-build` controls it.

The paged KV cache helps manage memory for the KV cache more efficiently (see
this [Document](gpt_attention.html#paged-kv-cache)). It usually leads to an
increase in the achievable batch size and improved efficiency.

### In-flight Sequence Batching

In-flight sequence batching is enabled by default with `trtllm-build`. It
requires that the GPT attention plugin, input padding removal, and the paged KV
cache are all enabled together.

In-flight sequence batching schedules sequences in the context phase together with
sequences in the generation phase to increase efficiency and reduce latency. See
this [Document](gpt_attention.html#inflight-batching) for more details.

### Multi-Block Mode

When the following conditions are met, it is recommended to try the
`--multi_block_mode` argument with `trtllm-build` and evaluate the impact on
performance:

1. `input_seq_len` > 1024 (an empirically derived value indicating that the
   context length is long enough),
2. `sequence_count` * `num_head` < `multiprocessor_count` / 2.

Multi-block mode can be beneficial when `batch_size * num_heads` is not large
enough to fully utilize the GPU (the number of CUDA thread blocks is low
compared to the number of streaming multiprocessors). In that case, multi-block
mode is expected to reduce the latency of the multi-head attention kernel in
the generation phase. However, it requires the context length to be long enough
for the work performed by each CUDA thread block to remain sufficient for
efficiency.
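
The two conditions above can be checked up front before deciding whether to rebuild the engine with `--multi_block_mode`. The helper below is a hypothetical convenience function, not part of TensorRT-LLM; the PyTorch call mentioned in the comment is only one possible way to obtain the streaming-multiprocessor count:

```python
def should_try_multi_block_mode(input_seq_len: int,
                                sequence_count: int,
                                num_heads: int,
                                multiprocessor_count: int) -> bool:
    """Heuristic from the two conditions above: long context and low GPU occupancy."""
    long_enough_context = input_seq_len > 1024
    low_occupancy = sequence_count * num_heads < multiprocessor_count / 2
    return long_enough_context and low_occupancy

# multiprocessor_count can be queried, for example, with PyTorch:
#   torch.cuda.get_device_properties(0).multi_processor_count
print(should_try_multi_block_mode(input_seq_len=4096, sequence_count=4,
                                  num_heads=8, multiprocessor_count=132))  # True (H100 SXM has 132 SMs)
```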

### Custom AllReduce Plugin

On NVLink-based nodes, it is recommended to enable the custom AllReduce plugin
by using the `--use_custom_all_reduce` argument with `trtllm-build`. On PCIe-based
nodes, it is not recommended to enable that plugin.

The custom AllReduce plugin activates a latency-optimized algorithm for
the AllReduce operation instead of the native NCCL operator. However, the
performance benefits may not be seen on PCIe-based systems.

### Embedding Parallelism, Embedding Sharing, and Look-Up Plugin

The embedding parallelism feature enables the sharding of the embedding table
across multiple GPUs, so that memory usage is reduced and throughput is
improved. The embedding sharing feature enables the sharing of the embedding
table between the `look_up` and `lm_head` layers.

The look-up plugin implements the embedding sharing feature and is required to
enable the aforementioned features for now (until TensorRT native layers
support embedding sharing).

It is recommended to enable the embedding parallelism and sharing features to
improve throughput. However, the following conditions have to be satisfied:

1. The model shares the embedding table between the `look_up` and `lm_head` layers,
2. Both the look_up plugin and the gemm plugin are enabled,
3. The sharding dimension of the embedding lookup table is set correctly.

To enable those features, use the `--use_parallel_embedding`,
`--use_embedding_sharing`, `--use_lookup_plugin`, and `--use_gemm_plugin`
arguments, and set the correct dimension with the `--embedding_sharding_dim`
argument of `trtllm-build`. See these
[Examples](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/../../examples/gpt/README.md)
for details.
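
To make the sharding dimension concrete, here is a minimal NumPy sketch, independent of the TensorRT-LLM API, of what splitting an embedding table across two tensor-parallel ranks looks like along each dimension; the mapping of `--embedding_sharding_dim` values to axes follows the usual convention (0 = vocabulary, 1 = hidden) and should be verified against your version:

```python
import numpy as np

vocab_size, hidden_size, tp_size = 32000, 4096, 2
embedding_table = np.zeros((vocab_size, hidden_size), dtype=np.float16)

# embedding_sharding_dim = 0: split along the vocabulary dimension,
# each rank holds vocab_size / tp_size rows of the table.
shards_dim0 = np.split(embedding_table, tp_size, axis=0)
print([s.shape for s in shards_dim0])  # [(16000, 4096), (16000, 4096)]

# embedding_sharding_dim = 1: split along the hidden dimension,
# each rank holds hidden_size / tp_size columns of the table.
shards_dim1 = np.split(embedding_table, tp_size, axis=1)
print([s.shape for s in shards_dim1])  # [(32000, 2048), (32000, 2048)]
```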

### Horizontal Fusion in Gated-MLP

Horizontal fusion in Gated-MLP combines two Matmul operations into a single one,
followed by a separate SwiGLU kernel. If both the model and batch sizes are large,
it is recommended to enable the feature by using the `--use_fused_mlp` argument
with `trtllm-build`. When the workload is very small, it is not recommended to
enable that feature.
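
Conceptually, the fusion works because the gate and up projections share the same input, so their weight matrices can be concatenated and applied in one GEMM. The NumPy sketch below (not TensorRT-LLM code; all names are illustrative) shows the equivalence for a SwiGLU-style Gated-MLP:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

batch, hidden, inter = 4, 8, 16
rng = np.random.default_rng(0)
x = rng.standard_normal((batch, hidden))
w_gate = rng.standard_normal((hidden, inter))
w_up = rng.standard_normal((hidden, inter))

# Unfused Gated-MLP front-end: two separate Matmul operations.
ref = silu(x @ w_gate) * (x @ w_up)

# Horizontally fused: one Matmul on the concatenated weights, then a SwiGLU step
# that splits the result and applies the activation.
w_fused = np.concatenate([w_gate, w_up], axis=1)  # [hidden, 2 * inter]
fused_out = x @ w_fused                           # single Matmul
gate_f, up_f = np.split(fused_out, 2, axis=1)
fused = silu(gate_f) * up_f

print(np.allclose(ref, fused))  # True
```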

### BERT Attention Plugin and Context Fused Multi-Head Attention

BERT attention plugin and context fused multi-head attention are both
recommended for the BERT model. They are enabled by default using the
`--bert_attention_plugin` and `--context_fmha` arguments with
`trtllm-build`.

## Runtime Options to Optimize the Performance of TensorRT-LLM Models

This part summarizes the runtime configuration knobs that can be tweaked to
enhance the performance of already-built engines. Note that the
configurations can currently be modified using the
[Batch Manager API](batch_manager.html#the-batch-manager-api)
as well as the
[TensorRT-LLM backend](https://github.com/triton-inference-server/tensorrtllm_backend).

### GPT Model Type

The GPT model type can be set to `V1`, `inflight_batching`, or
`inflight_fused_batching`. It is recommended to use `inflight_fused_batching`
to increase throughput and reduce latency.

### Max Tokens in Paged KV Cache and KV Cache Free GPU Memory Fraction

The `max_tokens_in_paged_kv_cache` and `kv_cache_free_gpu_mem_fraction`
parameters can be used to control the maximum number of tokens handled by the
KV cache manager. Setting them properly helps better control the amount of
memory available to the KV cache manager during inference. Keep in mind
that increasing the amount of memory available to the KV cache manager tends to
translate into a higher achievable throughput.

The `max_tokens_in_paged_kv_cache` flag directly sets the maximum number of
tokens in the KV cache manager. When left unset, that value is computed
from the `kv_cache_free_gpu_mem_fraction` setting.

The `kv_cache_free_gpu_mem_fraction` parameter is a floating-point number between `0.0`
and `1.0` that indicates the maximum fraction of GPU memory (after loading the
model) that will be used for the KV cache. The default value is `0.90`, which
means that 90% of the free GPU memory is used to store tokens in the KV
cache. Based on that value, TensorRT-LLM can determine the maximum number of
tokens in the KV cache manager.

When both parameters are set, the maximum number of tokens in the KV cache
manager is set to the smaller of `max_tokens_in_paged_kv_cache`
and the value computed from the amount of memory available for the KV cache.

Unless users clearly know the maximum number of tokens in the KV cache needed
by the model, it is recommended to leave `max_tokens_in_paged_kv_cache` unset.
For `kv_cache_free_gpu_mem_fraction`, if no other programs are executed on the
same GPU, it is recommended to test with a value as high as `0.95` to target
high throughput. Note that `kv_cache_free_gpu_mem_fraction`
cannot be set to `1.0` because some amount of memory has to be reserved for
inputs and outputs.
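
To see how the free-memory fraction translates into a token budget, here is a rough sketch under simple assumptions: the per-token size formula (two tensors, K and V, per layer) is a generic estimate rather than a TensorRT-LLM API, the cache is assumed dense and unquantized, and the model/GPU numbers are placeholders:

```python
def estimate_kv_cache_token_capacity(free_gpu_mem_bytes: int,
                                     kv_cache_free_gpu_mem_fraction: float,
                                     num_layers: int,
                                     num_kv_heads: int,
                                     head_size: int,
                                     bytes_per_element: int = 2) -> int:
    """Rough upper bound on KV cache tokens for a dense, unquantized cache."""
    # K and V tensors are stored for every layer, hence the factor of 2.
    bytes_per_token = 2 * num_layers * num_kv_heads * head_size * bytes_per_element
    budget = free_gpu_mem_bytes * kv_cache_free_gpu_mem_fraction
    return int(budget // bytes_per_token)

# Placeholder example: ~40 GiB free after loading a 7B-class model, FP16 KV cache,
# 32 layers, 32 KV heads of size 128, default fraction of 0.90.
free_bytes = 40 * 1024**3
print(estimate_kv_cache_token_capacity(free_bytes, 0.90, 32, 32, 128))  # 73728 tokens
```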

### Batch Scheduler Policy

There are currently two batch scheduler policies: `MAX_UTILIZATION` and
`GUARANTEED_NO_EVICT`.

As explained in the [GPT Manager Design](batch_manager.html#gptmanager-design)
section, the scheduling policy can be set to `MAX_UTILIZATION` to pack as many
requests as possible at each iteration of the forward loop when in-flight
sequence batching is enabled. It maximizes the utilization of the GPUs by
aggressively scheduling requests, at the risk of having to pause requests if the
KV cache size limit is reached.

For a more conservative approach with respect to the KV cache limitations in
terms of memory allocation, `schedulerPolicy` should be set to
`GUARANTEED_NO_EVICT` to guarantee that a started request is never paused.

If the goal is to maximize throughput, users should try `MAX_UTILIZATION`.
However, keep in mind that it may have a negative impact on latency if requests
have to be paused.

### TensorRT Overlap

When TensorRT overlap is enabled, available requests are partitioned into 2
micro-batches that can be run concurrently. It allows TensorRT-LLM to hide
exposed CPU runtime overhead. However, it may not provide performance benefits when
the model is not big enough to overlap the host overhead, or when the
number of requests is too small.

If the goal is to increase throughput, it is recommended to try enabling that
option. However, note that it may actually hurt latency.

### Maximum Attention Window Size

The `max_attention_window_size` flag sets the maximum number of tokens that are
attended to in order to generate one token when techniques like sliding-window
attention are used. See this
[Document](gpt_attention.html#sliding-window-attention-cyclic-rolling-buffer-kv-cache)
for more details. It defaults to the maximum sequence length
(`max_input_length + max_output_length` when building the engine), which means
that the feature is disabled by default.

When set to a smaller value than `max_input_length + max_output_length` (during
engine build), only the KV cache of the last `max_attention_window_size` tokens
is stored. If the input sequence length at runtime exceeds the
`max_attention_window_size` value, the accuracy may start dropping, but the
runtime performance will be better (due to the reduction in computation
and GPU memory allocation). Users can modify that value to
increase runtime performance at the expense of reduced accuracy.
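
As a quick back-of-the-envelope illustration (plain Python, placeholder numbers), the KV cache per sequence grows only up to the window size, so its footprint scales with `min(sequence_length, max_attention_window_size)`:

```python
def kv_tokens_stored(sequence_length: int, max_attention_window_size: int) -> int:
    """Tokens kept in the KV cache for one sequence under a sliding window."""
    return min(sequence_length, max_attention_window_size)

# Placeholder: engine built for 8192-token sequences, window limited to 2048 tokens.
full = kv_tokens_stored(8192, 8192)       # 8192 -> feature effectively disabled
windowed = kv_tokens_stored(8192, 2048)   # 2048 -> ~4x less KV cache for long sequences
print(full, windowed, full / windowed)    # 8192 2048 4.0
```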

### Chunked Context

Turning on context chunking by specifying `enable_chunked_context` in
`TrtGptModelOptionalParams` increases the chance of batching requests in the
context phase together with requests in the generation phase, thereby balancing
the amount of computation per iteration and increasing throughput. When this
feature is turned on, different performance can be obtained by adjusting
`max_num_tokens`. Its recommended value is usually `N * tokens_per_block`,
where `N` is an integer; start from `N = 1` and increase it until the best
performance is achieved.
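
For example, a minimal sketch of generating the candidate `max_num_tokens` values to sweep when chunked context is enabled (the block size is a placeholder; use the value from your build configuration):

```python
# tokens_per_block is the KV cache block size the engine was built with
# (64 here is a placeholder, not a recommendation).
tokens_per_block = 64

# Candidate --max_num_tokens values N * tokens_per_block, for N = 1, 2, 3, ...
candidates = [n * tokens_per_block for n in range(1, 9)]
print(candidates)  # [64, 128, 192, 256, 320, 384, 448, 512]
```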