mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* Update docs * Update windows install version Update gh pages (#2741) update gh pages (#2743) gh pages update (#2746) Update gh-pages (#2764) Update
552 lines
41 KiB
HTML
552 lines
41 KiB
HTML
|
||
|
||
<!DOCTYPE html>
|
||
<html class="writer-html5" lang="en" data-content_root="../">
|
||
<head>
|
||
<meta charset="utf-8" /><meta name="viewport" content="width=device-width, initial-scale=1" />
|
||
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||
<title>FP8 Quantization — tensorrt_llm documentation</title>
|
||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=b86133f3" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/css/theme.css?v=e59714d7" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/copybutton.css?v=76b2166b" />
|
||
|
||
|
||
<script src="../_static/jquery.js?v=5d32c60e"></script>
|
||
<script src="../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
|
||
<script src="../_static/documentation_options.js?v=5929fcd5"></script>
|
||
<script src="../_static/doctools.js?v=9bcbadda"></script>
|
||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||
<script src="../_static/clipboard.min.js?v=a7894cd8"></script>
|
||
<script src="../_static/copybutton.js?v=65e89d2a"></script>
|
||
<script src="../_static/js/theme.js"></script>
|
||
<link rel="index" title="Index" href="../genindex.html" />
|
||
<link rel="search" title="Search" href="../search.html" />
|
||
</head>
|
||
|
||
<body class="wy-body-for-nav">
|
||
<div class="wy-grid-for-nav">
|
||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||
<div class="wy-side-scroll">
|
||
<div class="wy-side-nav-search" >
|
||
|
||
|
||
|
||
<a href="../index.html" class="icon icon-home">
|
||
tensorrt_llm
|
||
</a>
|
||
<div role="search">
|
||
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
|
||
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
|
||
<input type="hidden" name="check_keywords" value="yes" />
|
||
<input type="hidden" name="area" value="default" />
|
||
</form>
|
||
</div>
|
||
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
|
||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../overview.html">Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../quick-start-guide.html">Quick Start Guide</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../key-features.html">Key Features</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../torch.html">PyTorch Backend</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../release-notes.html">Release Notes</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Installation</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../installation/linux.html">Installing on Linux</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../installation/build-from-source-linux.html">Building from Source Code on Linux</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../installation/windows.html">Installing on Windows</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../installation/build-from-source-windows.html">Building from Source Code on Windows</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../installation/grace-hopper.html">Installing on Grace Hopper</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">LLM API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../llm-api/index.html">API Introduction</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../llm-api/reference.html">API Reference</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">LLM API Examples</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../llm-api-examples/index.html">LLM Examples Introduction</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../llm-api-examples/customization.html">Common Customizations</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../llm-api-examples/llm_api_examples.html">Examples</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Model Definition API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.layers.html">Layers</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.functional.html">Functionals</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.models.html">Models</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../_cpp_gen/executor.html">Executor</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../_cpp_gen/runtime.html">Runtime</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Command-Line Reference</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../commands/trtllm-build.html">trtllm-build</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../commands/trtllm-serve.html">trtllm-serve</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Architecture</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/overview.html">TensorRT-LLM Architecture</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/core-concepts.html">Model Definition</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/core-concepts.html#compilation">Compilation</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/core-concepts.html#runtime">Runtime</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/core-concepts.html#multi-gpu-and-multi-node-support">Multi-GPU and Multi-Node Support</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/checkpoint.html">TensorRT-LLM Checkpoint</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/workflow.html">TensorRT-LLM Build Workflow</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../architecture/add-model.html">Adding a Model</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Advanced</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../advanced/gpt-attention.html">Multi-Head, Multi-Query, and Group-Query Attention</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../advanced/gpt-runtime.html">C++ GPT Runtime</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../advanced/executor.html">Executor API</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../advanced/graph-rewriting.html">Graph Rewriting Module</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../advanced/inference-request.html">Inference Request</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../advanced/inference-request.html#responses">Responses</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../advanced/lora.html">Run gpt-2b + LoRA using GptManager / cpp runtime</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../advanced/speculative-decoding.html">Speculative Sampling</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="perf-overview.html">Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="perf-benchmarking.html">Benchmarking</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="introduction.html">Best Practices</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="perf-analysis.html">Performance Analysis</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Reference</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../reference/troubleshooting.html">Troubleshooting</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../reference/support-matrix.html">Support Matrix</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../reference/precision.html">Numerical Precision</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../reference/memory.html">Memory Usage of TensorRT-LLM</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Blogs</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../blogs/Falcon180B-H200.html">Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../blogs/quantization-in-TRT-LLM.html">Speed up inference with SOTA quantization techniques in TRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../blogs/XQA-kernel.html">New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget</a></li>
|
||
</ul>
|
||
|
||
</div>
|
||
</div>
|
||
</nav>
|
||
|
||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
|
||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||
<a href="../index.html">tensorrt_llm</a>
|
||
</nav>
|
||
|
||
<div class="wy-nav-content">
|
||
<div class="rst-content">
|
||
<div role="navigation" aria-label="Page navigation">
|
||
<ul class="wy-breadcrumbs">
|
||
<li><a href="../index.html" class="icon icon-home" aria-label="Home"></a></li>
|
||
<li class="breadcrumb-item active">FP8 Quantization</li>
|
||
<li class="wy-breadcrumbs-aside">
|
||
<a href="../_sources/performance/fp8-quantization.md.txt" rel="nofollow"> View page source</a>
|
||
</li>
|
||
</ul>
|
||
<hr/>
|
||
</div>
|
||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||
<div itemprop="articleBody">
|
||
|
||
<section id="fp8-quantization">
|
||
<span id="id1"></span><h1>FP8 Quantization<a class="headerlink" href="#fp8-quantization" title="Link to this heading"></a></h1>
|
||
<p>Quantization is a technique that allows models to run in lower precisions, such as int8 and fp8, while maintaining acceptable output quality. Running in lower precisions can greatly boost performance, significantly increasing throughput and decreasing latency. The tradeoff is a drop in output quality, but in many cases the output quality is still acceptable, and many real-world deployments use quantization. If you want to learn more about quantization, refer to <a class="reference external" href="https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/">Mastering LLM Techniques - Inference Optimization</a>.</p>
|
||
<p>This section walks through enabling fp8 quantization and highlights some fp8 quantization-specific configuration options for boosting performance. It also continues the case study of Llama-3.3-70B split across 4 H100-sxm-80GB GPUs via tensor parallelism and showcases the effects of enabling these configuration options on performance.</p>
|
||
<blockquote>
|
||
<div><p>Disclaimer: While performance numbers shown here are real, they are only for demonstration purposes. Differences in environment, SKU, interconnect, and workload can all significantly affect performance and lead to your results differing from what is shown here.</p>
|
||
</div></blockquote>
|
||
<section id="enabling-quantization">
|
||
<h2>Enabling Quantization<a class="headerlink" href="#enabling-quantization" title="Link to this heading"></a></h2>
|
||
<p>To enable quantization, you need to configure the <code class="docutils literal notranslate"><span class="pre">QuantConfig</span></code> class and pass it to the <code class="docutils literal notranslate"><span class="pre">quant_config</span></code> parameter of the LLM class. At a minimum, the <code class="docutils literal notranslate"><span class="pre">quant_algo</span></code> parameter, which sets the quantization algorithm (fp8, fp8 per-token, int8awq, etc.), must be specified. You can find all supported quantization algorithms and other configurable options for <code class="docutils literal notranslate"><span class="pre">QuantConfig</span></code> in the LLM-API->Reference section of the docs. A calibration dataset is not required if you are using weights/checkpoints that are already quantized. However, if you are using an unquantized checkpoint (for example, fp16 or bf16), you also need to specify the calibration dataset that will be used to determine the quantization scales via <code class="docutils literal notranslate"><span class="pre">CalibConfig</span></code>. <code class="docutils literal notranslate"><span class="pre">CalibConfig</span></code> provides several options for setting the calibration dataset, which are also documented in the LLM-API->Reference section of the docs. Although TensorRT-LLM supports several other types of quantization, this guide focuses on fp8.</p>
|
||
<p>Here is an example of building and saving an fp8 engine from a bf16 checkpoint (note that fp8 is supported only on devices with compute capability 8.9 or higher: Ada, Hopper, Blackwell, and beyond):</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm</span><span class="w"> </span><span class="kn">import</span> <span class="n">LLM</span><span class="p">,</span> <span class="n">BuildConfig</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm.llmapi</span><span class="w"> </span><span class="kn">import</span> <span class="n">QuantConfig</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="p">,</span> <span class="n">CalibConfig</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">main</span><span class="p">():</span>
|
||
|
||
<span class="n">quant_config</span> <span class="o">=</span> <span class="n">QuantConfig</span><span class="p">(</span><span class="n">quant_algo</span><span class="o">=</span><span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span><span class="p">)</span>
|
||
|
||
<span class="n">calib_config</span> <span class="o">=</span> <span class="n">CalibConfig</span><span class="p">(</span>
|
||
<span class="n">calib_batches</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span>
|
||
<span class="n">calib_batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">calib_max_seq_length</span><span class="o">=</span><span class="mi">2048</span><span class="p">,</span>
|
||
<span class="n">tokenizer_max_seq_length</span><span class="o">=</span><span class="mi">4096</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">build_config</span> <span class="o">=</span> <span class="n">BuildConfig</span><span class="p">(</span>
|
||
<span class="n">max_num_tokens</span><span class="o">=</span><span class="mi">2048</span><span class="p">,</span>
|
||
<span class="n">max_batch_size</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">multiple_profiles</span> <span class="o">=</span> <span class="kc">True</span>
|
||
|
||
<span class="n">llm</span> <span class="o">=</span> <span class="n">LLM</span><span class="p">(</span>
|
||
<span class="n">model</span><span class="o">=</span><span class="s2">"/path/to/Llama-3.3-70B"</span><span class="p">,</span>
|
||
<span class="n">tensor_parallel_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span>
|
||
<span class="n">pipeline_parallel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">build_config</span><span class="o">=</span><span class="n">build_config</span><span class="p">,</span>
|
||
<span class="n">quant_config</span><span class="o">=</span><span class="n">quant_config</span><span class="p">,</span>
|
||
<span class="n">calib_config</span><span class="o">=</span><span class="n">calib_config</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">llm</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"baseline_fp8_engine"</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||
<span class="n">main</span><span class="p">()</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>For an example of how to build an fp8 engine using the <a class="reference internal" href="benchmarking-default-performance.html#building-and-saving-engines-via-cli"><span class="std std-ref">TensorRT-LLM CLI workflow</span></a>, see <a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama">TensorRT-LLM LLaMA examples</a>. In short, you first run <a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/quantization"><code class="docutils literal notranslate"><span class="pre">examples/quantization/quantize.py</span></code></a> to quantize and convert the model checkpoint to TensorRT-LLM format, and then use <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code>.</p>
|
||
<blockquote>
|
||
<div><p><em><strong>Note: While quantization aims to preserve model accuracy, this is not guaranteed, and it is extremely important that you check that the quality of outputs remains sufficient after quantization.</strong></em></p>
|
||
</div></blockquote>
|
||
</section>
|
||
<section id="fp8-baseline-performance">
|
||
<h2>FP8 “Baseline” Performance<a class="headerlink" href="#fp8-baseline-performance" title="Link to this heading"></a></h2>
|
||
<p>Benchmarking the engine produced by the example above yielded the following performance results. Note that we enabled some of the build flags we mentioned <a class="reference internal" href="useful-build-time-flags.html"><span class="std std-doc">earlier</span></a> (multiple profiles, paged_context_fmha) and also tuned max batch size and max num tokens. This gives a sense of the performance that is achievable when you tune an fp8 engine but exclude options that have been tailored for quantization. We recommend disabling the GEMM plugin for quantized engines, which is why it is not included here (it is off by default). Reduce fusion has a quantization-specific optimization that will be covered later. For the remainder of this page, we will refer to this setup as the “baseline” numbers for fp8.</p>
|
||
<table class="docutils align-default">
|
||
<thead>
|
||
<tr class="row-odd"><th class="head"><p>Metric</p></th>
|
||
<th class="head"><p>Value</p></th>
|
||
</tr>
|
||
</thead>
|
||
<tbody>
|
||
<tr class="row-even"><td><p>Token Throughput (tokens/sec)</p></td>
|
||
<td><p>3389.5305</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td><p>Request Throughput (req/sec)</p></td>
|
||
<td><p>1.6550</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td><p>Average Time To First Token (ms)</p></td>
|
||
<td><p>96.1597</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td><p>Average Inter-Token Latency (ms)</p></td>
|
||
<td><p>12.4248</p></td>
|
||
</tr>
|
||
</tbody>
|
||
</table>
|
||
</section>
|
||
<section id="quantized-kv-cache">
|
||
<h2>Quantized KV-Cache<a class="headerlink" href="#quantized-kv-cache" title="Link to this heading"></a></h2>
|
||
<p>By default, the KV-Cache is not quantized, but TensorRT-LLM supports quantizing the KV-Cache to further improve performance. However, quantizing the model more aggressively also increases the risk of model output quality degrading, so it is important to check output quality when using this feature.</p>
|
||
<section id="enabling-quantized-kv-cache">
|
||
<h3>Enabling Quantized KV Cache<a class="headerlink" href="#enabling-quantized-kv-cache" title="Link to this heading"></a></h3>
|
||
<p>The LLM-API exposes the quantization algorithm to be used for the KV cache via the <code class="docutils literal notranslate"><span class="pre">kv_cache_quant_algo</span></code> field in <code class="docutils literal notranslate"><span class="pre">QuantConfig</span></code>. To enable fp8 KV cache, modify <code class="docutils literal notranslate"><span class="pre">QuantConfig</span></code> as follows:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">quant_config</span> <span class="o">=</span> <span class="n">QuantConfig</span><span class="p">(</span><span class="n">quant_algo</span><span class="o">=</span><span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span><span class="p">,</span>
|
||
<span class="n">kv_cache_quant_algo</span><span class="o">=</span><span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>If you are using the <a class="reference internal" href="benchmarking-default-performance.html#building-and-saving-engines-via-cli"><span class="std std-ref">CLI flow for building engines</span></a> pass <code class="docutils literal notranslate"><span class="pre">--kv_cache_dtype</span> <span class="pre">fp8</span></code> to <a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/quantization"><code class="docutils literal notranslate"><span class="pre">examples/quantization/quantize.py</span></code></a>.</p>
|
||
</section>
|
||
<section id="performance-with-quantized-kv-cache">
|
||
<h3>Performance with Quantized KV Cache<a class="headerlink" href="#performance-with-quantized-kv-cache" title="Link to this heading"></a></h3>
|
||
<table class="docutils align-default">
|
||
<thead>
|
||
<tr class="row-odd"><th class="head"><p>Metric</p></th>
|
||
<th class="head"><p>Baseline</p></th>
|
||
<th class="head"><p>FP8 KV-Cache ON</p></th>
|
||
</tr>
|
||
</thead>
|
||
<tbody>
|
||
<tr class="row-even"><td><p>Token Throughput (tokens/sec)</p></td>
|
||
<td><p>3389.5305</p></td>
|
||
<td><p>5299.6372</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td><p>Request Throughput (req/sec)</p></td>
|
||
<td><p>1.6550</p></td>
|
||
<td><p>2.5877</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td><p>Average Time To First Token (ms)</p></td>
|
||
<td><p>96.1597</p></td>
|
||
<td><p>97.1287</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td><p>Average Inter-Token Latency (ms)</p></td>
|
||
<td><p>12.4248</p></td>
|
||
<td><p>12.5496</p></td>
|
||
</tr>
|
||
</tbody>
|
||
</table>
|
||
</section>
|
||
</section>
|
||
<section id="reduce-norm-fusion-with-user-buffers-for-llama-models">
|
||
<h2>Reduce Norm Fusion with User Buffers for Llama Models<a class="headerlink" href="#reduce-norm-fusion-with-user-buffers-for-llama-models" title="Link to this heading"></a></h2>
|
||
<p>The <a class="reference internal" href="useful-build-time-flags.html#reduce-norm-fusion-plugin-for-llama-models"><span class="std std-ref">Reduce Norm Fusion</span></a> feature is supported for fp8. An additional optimization called “User Buffers” is also supported for fp8 models. The user buffer feature aims to eliminate extra copies from the local buffer to the shared buffer in the communication kernel, leading to improved end-to-end performance.</p>
|
||
<section id="enabling-reduce-norm-fusion-with-user-buffers">
|
||
<h3>Enabling Reduce Norm Fusion with User Buffers<a class="headerlink" href="#enabling-reduce-norm-fusion-with-user-buffers" title="Link to this heading"></a></h3>
|
||
<p>To enable reduce norm fusion with user buffers, add the following lines below <code class="docutils literal notranslate"><span class="pre">BuildConfig</span></code>’s initialization:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">reduce_fusion</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">user_buffer</span> <span class="o">=</span> <span class="kc">True</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>If you are using the <a class="reference internal" href="benchmarking-default-performance.html#building-and-saving-engines-via-cli"><span class="std std-ref">CLI flow for building engines</span></a> pass <code class="docutils literal notranslate"><span class="pre">--reduce_fusion</span> <span class="pre">enable</span></code> and <code class="docutils literal notranslate"><span class="pre">--user_buffer</span> <span class="pre">enable</span></code> to <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code> to enable the feature.</p>
|
||
<blockquote>
|
||
<div><p>Note: You must enable <code class="docutils literal notranslate"><span class="pre">reduce_fusion</span></code> in order to enable <code class="docutils literal notranslate"><span class="pre">user_buffer</span></code>.</p>
|
||
</div></blockquote>
|
||
</section>
|
||
<section id="performance-with-reduce-norm-fusion-user-buffers">
|
||
<h3>Performance with Reduce Norm Fusion + User Buffers<a class="headerlink" href="#performance-with-reduce-norm-fusion-user-buffers" title="Link to this heading"></a></h3>
|
||
<p>Reduce Norm Fusion + User Buffer OFF: Same engine previously referred to as FP8 KV-Cache ON.</p>
|
||
<p>Reduce Norm Fusion + User Buffer ON: Previous example with reduce fusion and user buffers enabled. Max num tokens was set to 16384 and max batch size was set to 512 after tuning.</p>
|
||
<table class="docutils align-default">
|
||
<thead>
|
||
<tr class="row-odd"><th class="head"><p>Metric</p></th>
|
||
<th class="head"><p>Reduce Norm Fusion + User Buffer OFF</p></th>
|
||
<th class="head"><p>Reduce Norm Fusion + User Buffer ON</p></th>
|
||
</tr>
|
||
</thead>
|
||
<tbody>
|
||
<tr class="row-even"><td><p>Token Throughput (tokens/sec)</p></td>
|
||
<td><p>5299.6372</p></td>
|
||
<td><p>5980.7842</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td><p>Request Throughput (req/sec)</p></td>
|
||
<td><p>2.5877</p></td>
|
||
<td><p>2.9203</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td><p>Average Time To First Token (ms)</p></td>
|
||
<td><p>97.1287</p></td>
|
||
<td><p>82.2679</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td><p>Average Inter-Token Latency (ms)</p></td>
|
||
<td><p>12.5496</p></td>
|
||
<td><p>12.6975</p></td>
|
||
</tr>
|
||
</tbody>
|
||
</table>
|
||
</section>
|
||
</section>
|
||
<section id="gemm-swiglu-fusion-in-gated-mlp">
|
||
<h2>GEMM + SwiGLU Fusion in Gated-MLP<a class="headerlink" href="#gemm-swiglu-fusion-in-gated-mlp" title="Link to this heading"></a></h2>
|
||
<p>The GEMM + SwiGLU fusion in Gated-MLP combines two Matmul operations and one SwiGLU operation into a single kernel. Currently this is only supported for FP8 precision on Hopper. While this fusion improves performance, it can slightly reduce accuracy in FP8 PTQ because one quantization scaling factor is discarded.</p>
|
||
<p>We recommend enabling this feature for large models running on Hopper with FP8 precision. We do not recommend enabling this feature for very small workloads or if the accuracy loss is unacceptable.</p>
|
||
<section id="enabling-gemm-swiglu-fusion">
|
||
<h3>Enabling GEMM + SwiGLU Fusion<a class="headerlink" href="#enabling-gemm-swiglu-fusion" title="Link to this heading"></a></h3>
|
||
<p>To enable the GEMM + SwiGLU fusion, add the following line below <code class="docutils literal notranslate"><span class="pre">BuildConfig</span></code>’s initialization:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_swiglu_plugin</span> <span class="o">=</span> <span class="s1">'fp8'</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>For small batch size cases where latency is important, you can replace the above line with:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">low_latency_gemm_swiglu_plugin</span> <span class="o">=</span> <span class="s1">'fp8'</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>If you are using the <a class="reference internal" href="benchmarking-default-performance.html#building-and-saving-engines-via-cli"><span class="std std-ref">CLI flow for building engines</span></a> pass <code class="docutils literal notranslate"><span class="pre">--gemm_swiglu_plugin=fp8</span></code> or <code class="docutils literal notranslate"><span class="pre">--low_latency_gemm_swiglu_plugin=fp8</span></code> for the low latency case (only include one or the other) to <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code>.</p>
|
||
</section>
|
||
<section id="performance-with-gemm-swiglu-fusion">
|
||
<h3>Performance with GEMM + SwiGLU Fusion<a class="headerlink" href="#performance-with-gemm-swiglu-fusion" title="Link to this heading"></a></h3>
|
||
<table class="docutils align-default">
|
||
<thead>
|
||
<tr class="row-odd"><th class="head"><p>Metric</p></th>
|
||
<th class="head"><p>GEMM + SwiGLU fusion OFF</p></th>
|
||
<th class="head"><p>GEMM + SwiGLU fusion ON</p></th>
|
||
</tr>
|
||
</thead>
|
||
<tbody>
|
||
<tr class="row-even"><td><p>Token Throughput (tokens/sec)</p></td>
|
||
<td><p>5980.7842</p></td>
|
||
<td><p>5976.7977</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td><p>Request Throughput (req/sec)</p></td>
|
||
<td><p>2.9203</p></td>
|
||
<td><p>2.9184</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td><p>Average Time To First Token (ms)</p></td>
|
||
<td><p>82.2679</p></td>
|
||
<td><p>81.8841</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td><p>Average Inter-Token Latency (ms)</p></td>
|
||
<td><p>12.6975</p></td>
|
||
<td><p>11.7031</p></td>
|
||
</tr>
|
||
</tbody>
|
||
</table>
|
||
<p>In this case, the GEMM + SwiGLU plugin performs almost the same as when it is disabled: the small throughput drop is within run-to-run variance, and the TTFT and ITL improvements are slight. However, we found that when paired with the low-latency GEMM plugin discussed next, enabling this feature was necessary to reach maximum throughput.</p>
|
||
</section>
|
||
</section>
|
||
<section id="low-latency-gemm-plugin">
|
||
<h2>Low Latency GEMM Plugin<a class="headerlink" href="#low-latency-gemm-plugin" title="Link to this heading"></a></h2>
|
||
<p>Previously we mentioned the <a class="reference internal" href="useful-build-time-flags.html#gemm-plugin"><span class="std std-ref">GEMM Plugin</span></a> feature. Although it supports fp8, we recommend keeping it disabled for fp8 (it is disabled by default). However, for low-latency fp8 scenarios, we recommend trying the low-latency GEMM plugin to see whether it is effective for your workload.</p>
|
||
<section id="enabling-low-latency-gemm-plugin">
|
||
<h3>Enabling Low Latency GEMM plugin<a class="headerlink" href="#enabling-low-latency-gemm-plugin" title="Link to this heading"></a></h3>
|
||
<p>To enable the low-latency GEMM plugin, add the following line below <code class="docutils literal notranslate"><span class="pre">BuildConfig</span></code>’s initialization:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">low_latency_gemm_plugin</span> <span class="o">=</span> <span class="s1">'fp8'</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>If you are using the <a class="reference internal" href="benchmarking-default-performance.html#building-and-saving-engines-via-cli"><span class="std std-ref">CLI flow for building engines</span></a>, pass <code class="docutils literal notranslate"><span class="pre">--low_latency_gemm_plugin=fp8</span></code> to <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code> to enable the feature. Again, <strong>we recommend disabling the GEMM plugin for fp8</strong>, so if you are passing <code class="docutils literal notranslate"><span class="pre">--gemm_plugin=fp8</span></code> to <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code>, we recommend removing it.</p>
|
||
</section>
|
||
<section id="performance-with-low-latency-gemm-plugin">
|
||
<h3>Performance with Low Latency GEMM plugin<a class="headerlink" href="#performance-with-low-latency-gemm-plugin" title="Link to this heading"></a></h3>
|
||
<p>Low Latency GEMM ON uses the same configuration as the previous example (including the GEMM + SwiGLU fusion), but with the low-latency GEMM plugin enabled. After tuning, max num tokens was set to 16384 and max batch size was set to 512.</p>
|
||
<table class="docutils align-default">
|
||
<thead>
|
||
<tr class="row-odd"><th class="head"><p>Metric</p></th>
|
||
<th class="head"><p>Low Latency GEMM OFF</p></th>
|
||
<th class="head"><p>Low Latency GEMM ON</p></th>
|
||
</tr>
|
||
</thead>
|
||
<tbody>
|
||
<tr class="row-even"><td><p>Token Throughput (tokens/sec)</p></td>
|
||
<td><p>5976.7977</p></td>
|
||
<td><p>6049.1625</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td><p>Request Throughput (req/sec)</p></td>
|
||
<td><p>2.9184</p></td>
|
||
<td><p>2.9537</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td><p>Average Time To First Token (ms)</p></td>
|
||
<td><p>81.8841</p></td>
|
||
<td><p>88.0162</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td><p>Average Inter-Token Latency (ms)</p></td>
|
||
<td><p>11.7031</p></td>
|
||
<td><p>10.8225</p></td>
|
||
</tr>
|
||
</tbody>
|
||
</table>
|
||
<p>In this case, enabling the low-latency GEMM plugin provided a meaningful boost to throughput (about 1.2%). It also improved ITL, but at the expense of TTFT. Furthermore, when used without the GEMM + SwiGLU fusion, performance was actually worse than with the plugin turned off. This suggests that, for this workload, the low-latency GEMM plugin was choosing a worse kernel for the GEMM right before the SwiGLU. Once that GEMM was handled by the GEMM + SwiGLU fusion's custom kernel, the kernels the low-latency GEMM plugin chose for the remaining GEMMs were better than the baseline, resulting in improved performance. This underscores the importance of benchmarking different settings, as the impact of this plugin is highly workload dependent. If possible, some grid searching can be useful for extremely performance-sensitive workloads.</p>
|
||
</section>
|
||
</section>
|
||
<section id="conclusion">
|
||
<h2>Conclusion<a class="headerlink" href="#conclusion" title="Link to this heading"></a></h2>
|
||
<p>Overall, leveraging quantization can provide significant performance uplifts. Here are the uplifts from our tuned fp8 model compared to the tuned fp16 numbers we reached on the <a class="reference internal" href="tuning-max-batch-size-and-max-num-tokens.html"><span class="std std-doc">previous page of the guide</span></a> (for the latency metrics, TTFT and ITL, the improvement is the percentage reduction):</p>
|
||
<table class="docutils align-default">
|
||
<thead>
|
||
<tr class="row-odd"><th class="head"><p>Metric</p></th>
|
||
<th class="head"><p>Tuned FP16 Model</p></th>
|
||
<th class="head"><p>Tuned FP8 Model</p></th>
|
||
<th class="head"><p>% Improvement</p></th>
|
||
</tr>
|
||
</thead>
|
||
<tbody>
|
||
<tr class="row-even"><td><p>Token Throughput (tokens/sec)</p></td>
|
||
<td><p>2474.2581</p></td>
|
||
<td><p>6049.1625</p></td>
|
||
<td><p>144.48</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td><p>Request Throughput (req/sec)</p></td>
|
||
<td><p>1.2081</p></td>
|
||
<td><p>2.9537</p></td>
|
||
<td><p>144.49</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td><p>Average Time To First Token (ms)</p></td>
|
||
<td><p>147.5742</p></td>
|
||
<td><p>88.0162</p></td>
|
||
<td><p>40.36</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td><p>Average Inter-Token Latency (ms)</p></td>
|
||
<td><p>14.6852</p></td>
|
||
<td><p>10.8225</p></td>
|
||
<td><p>26.30</p></td>
|
||
</tr>
|
||
</tbody>
|
||
</table>
|
||
<p>Additionally, compared to the fp8 baseline numbers (the baseline numbers had some degree of tuning; see <a class="reference internal" href="#fp8-baseline-performance"><span class="std std-ref">Baseline Performance</span></a> for details), we saw the following performance uplifts from enabling the flags discussed above:</p>
|
||
<table class="docutils align-default">
|
||
<thead>
|
||
<tr class="row-odd"><th class="head"><p>Metric</p></th>
|
||
<th class="head"><p>Baseline FP8 Model</p></th>
|
||
<th class="head"><p>Tuned FP8 Model</p></th>
|
||
<th class="head"><p>% Improvement</p></th>
|
||
</tr>
|
||
</thead>
|
||
<tbody>
|
||
<tr class="row-even"><td><p>Token Throughput (tokens/sec)</p></td>
|
||
<td><p>3389.5305</p></td>
|
||
<td><p>6049.1625</p></td>
|
||
<td><p>78.47</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td><p>Request Throughput (req/sec)</p></td>
|
||
<td><p>1.6550</p></td>
|
||
<td><p>2.9537</p></td>
|
||
<td><p>78.47</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td><p>Average Time To First Token (ms)</p></td>
|
||
<td><p>96.1597</p></td>
|
||
<td><p>88.0162</p></td>
|
||
<td><p>8.47</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td><p>Average Inter-Token Latency (ms)</p></td>
|
||
<td><p>12.4248</p></td>
|
||
<td><p>10.8225</p></td>
|
||
<td><p>12.90</p></td>
|
||
</tr>
|
||
</tbody>
|
||
</table>
|
||
<p>As mentioned previously, the caveat with leveraging quantization is a potential drop in accuracy, so we strongly recommend having a way to test whether model output quality is acceptable before adopting quantization. That said, many real-world deployments use quantization successfully, and the significant performance boosts it enables are often worth the effort of checking whether it is a fit.</p>
|
||
<section id="summary-of-configuration-option-recommendations">
|
||
<h3>Summary of Configuration Option Recommendations:<a class="headerlink" href="#summary-of-configuration-option-recommendations" title="Link to this heading"></a></h3>
|
||
<ol class="arabic simple">
|
||
<li><p>Quantized KV-cache: Typically provides significant throughput boost. We recommend turning it on as long as output quality is still acceptable with the feature enabled.</p></li>
|
||
<li><p>Reduce fusion + user buffers: This feature is only supported on fp8 Llama and Mistral/Mixtral models. Effectiveness is workload dependent so we recommend turning it on and benchmarking to check.</p></li>
|
||
<li><p>GEMM + SwiGLU plugin: This feature is only supported on fp8 models with SwiGLU operators, such as Llama, Mixtral, etc. As with reduce fusion, effectiveness is workload dependent, so we recommend benchmarking to confirm it helps. It also carries an increased risk of affecting accuracy, since it drops a quantization scale.</p></li>
|
||
<li><p>Low-latency GEMM plugin: Effectiveness is workload dependent, so we recommend turning it on and benchmarking. As we saw in our case study, its effectiveness can also be affected by other flags, so if possible, benchmarking various combinations of configuration options is ideal.</p></li>
|
||
</ol>
|
||
</section>
|
||
</section>
|
||
</section>
|
||
|
||
|
||
</div>
|
||
</div>
|
||
<footer>
|
||
|
||
<hr/>
|
||
|
||
<div role="contentinfo">
|
||
<!-- A template block reference was emitted here as a raw object repr instead of rendered content; removed. -->
|
||
|
||
<div class="footer">
|
||
<p>
|
||
Copyright © 2024 NVIDIA Corporation
|
||
</p>
|
||
<p>
|
||
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/" target="_blank" rel="noopener"
|
||
data-cms-ai="0">Privacy Policy</a> |
|
||
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/" target="_blank" rel="noopener"
|
||
data-cms-ai="0">Manage My Privacy</a> |
|
||
<a class="Link" href="https://www.nvidia.com/en-us/preferences/start/" target="_blank" rel="noopener"
|
||
data-cms-ai="0">Do Not Sell or Share My Data</a> |
|
||
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/" target="_blank"
|
||
rel="noopener" data-cms-ai="0">Terms of Service</a> |
|
||
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/accessibility/" target="_blank" rel="noopener"
|
||
data-cms-ai="0">Accessibility</a> |
|
||
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/company-policies/" target="_blank"
|
||
rel="noopener" data-cms-ai="0">Corporate Policies</a> |
|
||
<a class="Link" href="https://www.nvidia.com/en-us/product-security/" target="_blank" rel="noopener"
|
||
data-cms-ai="0">Product Security</a> |
|
||
<a class="Link" href="https://www.nvidia.com/en-us/contact/" target="_blank" rel="noopener"
|
||
data-cms-ai="0">Contact</a>
|
||
</p>
|
||
</div>
|
||
|
||
|
||
</div>
|
||
|
||
|
||
|
||
</footer>
|
||
</div>
|
||
</div>
|
||
</section>
|
||
</div>
|
||
<script>
|
||
jQuery(function () {
|
||
SphinxRtdTheme.Navigation.enable(true);
|
||
});
|
||
</script>
|
||
|
||
</body>
|
||
</html> |