<!-- TensorRT-LLMs/architecture/workflow.html, 2024-12-25 13:44:02 +08:00 -->

<!DOCTYPE html>
<html class="writer-html5" lang="en" data-content_root="../">
<head>
<meta charset="utf-8" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>TensorRT-LLM Build Workflow &mdash; tensorrt_llm documentation</title>
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=80d5e7a1" />
<link rel="stylesheet" type="text/css" href="../_static/css/theme.css?v=e59714d7" />
<link rel="stylesheet" type="text/css" href="../_static/copybutton.css?v=76b2166b" />
<script src="../_static/jquery.js?v=5d32c60e"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../_static/doctools.js?v=888ff710"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../_static/copybutton.js?v=65e89d2a"></script>
<script src="../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="Adding a Model" href="add-model.html" />
<link rel="prev" title="TensorRT-LLM Checkpoint" href="checkpoint.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../index.html" class="icon icon-home">
tensorrt_llm
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../overview.html">Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../quick-start-guide.html">Quick Start Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../key-features.html">Key Features</a></li>
<li class="toctree-l1"><a class="reference internal" href="../release-notes.html">Release Notes</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Installation</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../installation/linux.html">Installing on Linux</a></li>
<li class="toctree-l1"><a class="reference internal" href="../installation/build-from-source-linux.html">Building from Source Code on Linux</a></li>
<li class="toctree-l1"><a class="reference internal" href="../installation/windows.html">Installing on Windows</a></li>
<li class="toctree-l1"><a class="reference internal" href="../installation/build-from-source-windows.html">Building from Source Code on Windows</a></li>
<li class="toctree-l1"><a class="reference internal" href="../installation/grace-hopper.html">Installing on Grace Hopper</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">LLM API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../llm-api/index.html">API Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../llm-api/reference.html">API Reference</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">LLM API Examples</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../llm-api-examples/index.html">LLM Examples Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../llm-api-examples/customization.html">Common Customizations</a></li>
<li class="toctree-l1"><a class="reference internal" href="../llm-api-examples/llm_api_examples.html">Examples</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Model Definition API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.layers.html">Layers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.functional.html">Functionals</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.models.html">Models</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../_cpp_gen/executor.html">Executor</a></li>
<li class="toctree-l1"><a class="reference internal" href="../_cpp_gen/runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Command-Line Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../commands/trtllm-build.html">trtllm-build</a></li>
<li class="toctree-l1"><a class="reference internal" href="../commands/trtllm-serve.html">trtllm-serve</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Architecture</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="overview.html">TensorRT-LLM Architecture</a></li>
<li class="toctree-l1"><a class="reference internal" href="core-concepts.html">Model Definition</a></li>
<li class="toctree-l1"><a class="reference internal" href="core-concepts.html#compilation">Compilation</a></li>
<li class="toctree-l1"><a class="reference internal" href="core-concepts.html#runtime">Runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="core-concepts.html#multi-gpu-and-multi-node-support">Multi-GPU and Multi-Node Support</a></li>
<li class="toctree-l1"><a class="reference internal" href="checkpoint.html">TensorRT-LLM Checkpoint</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">TensorRT-LLM Build Workflow</a><ul>
<li class="toctree-l2"><a class="reference internal" href="#overview">Overview</a></li>
<li class="toctree-l2"><a class="reference internal" href="#conversion-apis">Conversion APIs</a></li>
<li class="toctree-l2"><a class="reference internal" href="#quantization-apis">Quantization APIs</a></li>
<li class="toctree-l2"><a class="reference internal" href="#build-apis">Build APIs</a></li>
<li class="toctree-l2"><a class="reference internal" href="#cli-tools">CLI Tools</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="add-model.html">Adding a Model</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Advanced</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../advanced/gpt-attention.html">Multi-Head, Multi-Query, and Group-Query Attention</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/gpt-runtime.html">C++ GPT Runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/executor.html">Executor API</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/graph-rewriting.html">Graph Rewriting Module</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/inference-request.html">Inference Request</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/inference-request.html#responses">Responses</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/lora.html">Run gpt-2b + LoRA using GptManager / cpp runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/speculative-decoding.html">Speculative Sampling</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../performance/perf-overview.html">Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../performance/perf-benchmarking.html">Benchmarking</a></li>
<li class="toctree-l1"><a class="reference internal" href="../performance/perf-best-practices.html">Best Practices</a></li>
<li class="toctree-l1"><a class="reference internal" href="../performance/perf-analysis.html">Performance Analysis</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../reference/troubleshooting.html">Troubleshooting</a></li>
<li class="toctree-l1"><a class="reference internal" href="../reference/support-matrix.html">Support Matrix</a></li>
<li class="toctree-l1"><a class="reference internal" href="../reference/precision.html">Numerical Precision</a></li>
<li class="toctree-l1"><a class="reference internal" href="../reference/memory.html">Memory Usage of TensorRT-LLM</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Blogs</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
<li class="toctree-l1"><a class="reference internal" href="../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../blogs/Falcon180B-H200.html">Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100</a></li>
<li class="toctree-l1"><a class="reference internal" href="../blogs/quantization-in-TRT-LLM.html">Speed up inference with SOTA quantization techniques in TRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../blogs/XQA-kernel.html">New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../index.html">tensorrt_llm</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="../index.html" class="icon icon-home" aria-label="Home"></a></li>
<li class="breadcrumb-item active">TensorRT-LLM Build Workflow</li>
<li class="wy-breadcrumbs-aside">
<a href="../_sources/architecture/workflow.md.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<section id="tensorrt-llm-build-workflow">
<h1>TensorRT-LLM Build Workflow<a class="headerlink" href="#tensorrt-llm-build-workflow" title="Link to this heading"></a></h1>
<section id="overview">
<h2>Overview<a class="headerlink" href="#overview" title="Link to this heading"></a></h2>
<p>The build workflow contains two major steps.</p>
<ol class="arabic simple">
<li><p>Create TensorRT-LLM models from existing model checkpoints exported by the training framework.</p></li>
<li><p>Build the TensorRT-LLM models to TensorRT-LLM engines.</p></li>
</ol>
<p>To generalize the TensorRT-LLM optimization features to all models, and to let different models share the same workflow, TensorRT-LLM defines conventions for how models shall be defined and how they shall be imported.</p>
<p>The TensorRT-LLM checkpoint convention is documented in <a class="reference internal" href="checkpoint.html"><span class="std std-doc">TensorRT-LLM Checkpoint</span></a>, and all decoder-only models have been migrated to adopt it. Model-specific <code class="docutils literal notranslate"><span class="pre">convert_checkpoint.py</span></code> scripts are shipped as source code in the example directories, and a <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code> CLI tool has been added. However, shipping the convert-checkpoint scripts outside the core TensorRT-LLM library as examples has some disadvantages:</p>
<ol class="arabic simple">
<li><p>TensorRT-LLM evolves quickly, and the model definition code may change for better performance, which leaves the <code class="docutils literal notranslate"><span class="pre">convert_checkpoint.py</span></code> scripts out of date.</p></li>
<li><p>TensorRT-LLM is creating a new set of high-level APIs that handle model conversion, engine building, and inference in one class for ease of use. These high-level APIs need to call the weight-conversion code, which therefore must be part of the TensorRT-LLM core library, not the examples. The conversion code of different models must also share the same interface, so that the high-level APIs do not need ad-hoc code for each model.</p></li>
</ol>
<p>To mitigate these issues, the model-specific <code class="docutils literal notranslate"><span class="pre">convert_checkpoint.py</span></code> scripts are being refactored. Most of the conversion code is moving into the core library, next to the model definition; refer to <code class="docutils literal notranslate"><span class="pre">tensorrt_llm/models/llama/</span></code> as an example. A new set of APIs imports models and converts weights. The 0.9 release refactored the LLaMA model class to adopt the new APIs; the refactoring of the other models is ongoing.</p>
</section>
<section id="conversion-apis">
<h2>Conversion APIs<a class="headerlink" href="#conversion-apis" title="Link to this heading"></a></h2>
<p>The weight-conversion API for the LLaMA model looks like this: a <code class="docutils literal notranslate"><span class="pre">TopModelMixin</span></code> class is introduced that declares the <code class="docutils literal notranslate"><span class="pre">from_hugging_face()</span></code> interface, and the <code class="docutils literal notranslate"><span class="pre">LLaMAForCausalLM</span></code> class, which has <code class="docutils literal notranslate"><span class="pre">TopModelMixin</span></code> in its base-class hierarchy (not as a direct parent), implements that interface.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">TopModelMixin</span><span class="p">:</span>
    <span class="nd">@classmethod</span>
    <span class="k">def</span> <span class="nf">from_hugging_face</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span>
                          <span class="n">hf_model_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
                          <span class="n">dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s1">&#39;float16&#39;</span><span class="p">,</span>
                          <span class="n">mapping</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Mapping</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
                          <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;Subclass shall override this&quot;</span><span class="p">)</span>


<span class="c1"># TopModelMixin is part of the base class hierarchy</span>
<span class="k">class</span> <span class="nc">LLaMAForCausalLM</span><span class="p">(</span><span class="n">DecoderModelForCausalLM</span><span class="p">):</span>
    <span class="nd">@classmethod</span>
    <span class="k">def</span> <span class="nf">from_hugging_face</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span>
                          <span class="n">hf_model_dir</span><span class="p">,</span>
                          <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float16&#39;</span><span class="p">,</span>
                          <span class="n">mapping</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Mapping</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">LLaMAForCausalLM</span><span class="p">:</span>
        <span class="c1"># Create a TensorRT-LLM LLaMA model object</span>
        <span class="c1"># Convert the Hugging Face checkpoint to the TensorRT-LLM expected weights dict</span>
        <span class="c1"># Load the weights into the LLaMA model object</span>
        <span class="o">...</span>
</pre></div>
</div>
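<p>Stripped of all TensorRT-LLM specifics, the pattern above is an ordinary classmethod factory. The following is a minimal, runnable sketch of that pattern; every name in it (<code class="docutils literal notranslate"><span class="pre">TopMixin</span></code>, <code class="docutils literal notranslate"><span class="pre">ToyForCausalLM</span></code>, <code class="docutils literal notranslate"><span class="pre">from_pretrained_dir</span></code>) is a hypothetical stand-in, not a real TensorRT-LLM API:</p>

```python
from typing import Optional


class TopMixin:
    """Hypothetical stand-in for a mixin that declares the factory interface."""

    @classmethod
    def from_pretrained_dir(cls, model_dir: str,
                            dtype: Optional[str] = "float16", **kwargs):
        raise NotImplementedError("Subclass shall override this")


class ToyForCausalLM(TopMixin):
    """Hypothetical stand-in for a concrete model class."""

    def __init__(self, dtype):
        self.dtype = dtype
        self.weights = {}

    @classmethod
    def from_pretrained_dir(cls, model_dir: str,
                            dtype: Optional[str] = "float16", **kwargs):
        # 1) create the model object, 2) convert the external weights,
        # 3) load them into the object, then return it in memory
        model = cls(dtype)
        model.weights = {"transformer.wte.weight": f"converted from {model_dir}"}
        return model


model = ToyForCausalLM.from_pretrained_dir("/models/toy", dtype="bfloat16")
```

<p>Callers depend only on the factory signature, so the internals of the concrete class can change without breaking them.</p>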
<p>Then, in the convert_checkpoint.py script in the
<a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama/"><code class="docutils literal notranslate"><span class="pre">examples/llama/</span></code></a> directory of the GitHub repo,
the logic can be greatly simplified. Even if the model definition code of the TensorRT-LLM LLaMA class changes for some reason, the <code class="docutils literal notranslate"><span class="pre">from_hugging_face</span></code> API stays the same, so existing workflows using this interface are not affected.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1">#other args omitted for simplicity here.</span>
<span class="n">llama</span> <span class="o">=</span> <span class="n">LLaMAForCausalLM</span><span class="o">.</span><span class="n">from_hugging_face</span><span class="p">(</span><span class="n">model_dir</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">mapping</span><span class="o">=</span><span class="n">mapping</span><span class="p">)</span>
<span class="n">llama</span><span class="o">.</span><span class="n">save_checkpoint</span><span class="p">(</span><span class="n">output_dir</span><span class="p">,</span> <span class="n">save_config</span><span class="o">=</span><span class="p">(</span><span class="n">rank</span><span class="o">==</span><span class="mi">0</span><span class="p">))</span>
</pre></div>
</div>
<p>The <code class="docutils literal notranslate"><span class="pre">from_hugging_face</span></code> API intentionally does not save the checkpoint to disk; instead, it returns an in-memory object. Call <code class="docutils literal notranslate"><span class="pre">save_checkpoint</span></code> to save the model. This keeps the workflow flexible and makes a convert-then-build flow inside one process faster, because saving and reloading large models from disk is typically slow and should be avoided.</p>
<p>Since LLaMA models are also released in other formats, such as the Meta checkpoint, the <code class="docutils literal notranslate"><span class="pre">LLaMAForCausalLM</span></code> class has a <code class="docutils literal notranslate"><span class="pre">from_meta_ckpt</span></code> function for that. This function is not declared in the <code class="docutils literal notranslate"><span class="pre">TopModelMixin</span></code> class because it is LLaMA-specific, and therefore other models do not use it.</p>
<p>In the 0.9 release, only LLaMA is refactored. Since the popular LLaMA models (and their variants) are released in the Hugging Face and Meta checkpoint formats, only these two functions are implemented.</p>
<p>In future releases, <code class="docutils literal notranslate"><span class="pre">from_jax</span></code>, <code class="docutils literal notranslate"><span class="pre">from_nemo</span></code>, <code class="docutils literal notranslate"><span class="pre">from_keras</span></code>, or other factory methods for different training checkpoints might be added.
For example, the Gemma 2B model and the convert_checkpoint.py file in the <a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/gemma/"><code class="docutils literal notranslate"><span class="pre">examples/gemma</span></code></a>
directory support JAX and Keras formats in addition to Hugging Face. Model developers can choose to implement <strong>any subset</strong> of these factory methods for the models they contribute to TensorRT-LLM.</p>
<p>For formats that the TensorRT-LLM model developers do not support, you are still free to implement your own weight conversion outside the core library; the flow looks like this:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">config</span> <span class="o">=</span> <span class="n">read_config_from_the_custom_training_checkpoint</span><span class="p">(</span><span class="n">model_dir</span><span class="p">)</span>
<span class="n">llama</span> <span class="o">=</span> <span class="n">LLaMAForCausalLM</span><span class="p">(</span><span class="n">config</span><span class="p">)</span>
<span class="c1"># option 1:</span>
<span class="c1"># Create a weights dict and then calls LLaMAForCausalLM.load</span>
<span class="n">weights_dict</span> <span class="o">=</span> <span class="n">convert_weights_from_custom_training_checkpoint</span><span class="p">(</span><span class="n">model_dir</span><span class="p">)</span>
<span class="n">llama</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">weights_dict</span><span class="p">)</span>
<span class="c1"># option 2:</span>
<span class="c1"># Internally assign the model parameters directly</span>
<span class="n">convert_and_load_weights_into_trtllm_llama</span><span class="p">(</span><span class="n">llama</span><span class="p">,</span> <span class="n">model_dir</span><span class="p">)</span>
<span class="c1"># Use the llama object as usual, to save the checkpoint or build engines</span>
</pre></div>
</div>
<p>Custom weight loading like this has some limitations and pitfalls: because the model definition lives inside the TensorRT-LLM core library while the weight loading/conversion lives outside it, the conversion code might need to be updated whenever a new TensorRT-LLM version is released.</p>
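<p>Option 1 above hinges on the weights dict using exactly the parameter names that the model definition expects. A minimal, self-contained sketch of that contract, with validation that catches a mismatched checkpoint early (all names here are hypothetical, not real TensorRT-LLM APIs):</p>

```python
class ToyModel:
    """Hypothetical model with a fixed set of named parameters."""

    EXPECTED = {"lm_head.weight", "transformer.wte.weight"}

    def __init__(self):
        self.params = {}

    def load(self, weights_dict):
        # Reject checkpoints whose tensor names do not match the model
        # definition, instead of failing later at build time.
        missing = self.EXPECTED - set(weights_dict)
        unexpected = set(weights_dict) - self.EXPECTED
        if missing or unexpected:
            raise ValueError(f"missing={sorted(missing)} unexpected={sorted(unexpected)}")
        self.params.update(weights_dict)


def convert_weights_from_custom_checkpoint(model_dir):
    # A real converter would read tensors from model_dir and rename or
    # transpose them into the layout the model definition expects.
    return {"lm_head.weight": [0.0], "transformer.wte.weight": [0.0]}


model = ToyModel()
model.load(convert_weights_from_custom_checkpoint("/ckpt/custom"))
```

<p>This is why keeping the conversion code outside the core library is fragile: if the model definition renames a parameter in a new release, the external converter must be updated to match.</p>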
</section>
<section id="quantization-apis">
<h2>Quantization APIs<a class="headerlink" href="#quantization-apis" title="Link to this heading"></a></h2>
<p>TensorRT-LLM relies on the NVIDIA Modelopt toolkit to support some quantization algorithms, such as FP8, W4A16_AWQ, and W4A8_AWQ, while it also has its own quantization implementations for SmoothQuant, INT8 KV cache, and INT4/INT8 weight-only.</p>
<p>In the TensorRT-LLM 0.8 release:</p>
<ul class="simple">
<li><p>For Modelopt-supported quantization algorithms, a standalone script,
<a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py">examples/quantization/quantize.py</a>,
exports TensorRT-LLM checkpoints, and the <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code> command then builds those checkpoints into engines.</p></li>
<li><p>For the non-Modelopt quantization algorithms, users need to use the per-model convert_checkpoint.py scripts to export TensorRT-LLM checkpoints.</p></li>
</ul>
<p>The <code class="docutils literal notranslate"><span class="pre">quantize()</span></code> interface unifies these different quantization flows. Its default implementation is added in the <code class="docutils literal notranslate"><span class="pre">PretrainedModel</span></code> class.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">PretrainedModel</span><span class="p">:</span>
    <span class="nd">@classmethod</span>
    <span class="k">def</span> <span class="nf">quantize</span><span class="p">(</span>
            <span class="bp">cls</span><span class="p">,</span>
            <span class="n">hf_model_dir</span><span class="p">,</span>
            <span class="n">output_dir</span><span class="p">,</span>
            <span class="n">quant_config</span><span class="p">:</span> <span class="n">QuantConfig</span><span class="p">,</span>
            <span class="n">mapping</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Mapping</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>  <span class="c1"># some args are omitted here</span>
        <span class="c1"># Internally quantize the given Hugging Face model using Modelopt</span>
        <span class="c1"># and save the checkpoint to output_dir</span>
        <span class="o">...</span>
</pre></div>
</div>
<ul class="simple">
<li><p>The default implementation handles only Modelopt-supported quantization. The LLaMA class inherits from <code class="docutils literal notranslate"><span class="pre">PretrainedModel</span></code> and dispatches Modelopt quantization to the default implementation of the superclass.</p></li>
<li><p>Model developers raise errors in the subclass implementation if the new model is not yet supported by Modelopt.</p></li>
</ul>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">LLaMAForCausalLM</span><span class="p">:</span>
    <span class="nd">@classmethod</span>
    <span class="k">def</span> <span class="nf">quantize</span><span class="p">(</span>
            <span class="bp">cls</span><span class="p">,</span>
            <span class="n">hf_model_dir</span><span class="p">,</span>
            <span class="n">output_dir</span><span class="p">,</span>
            <span class="n">quant_config</span><span class="p">:</span> <span class="n">QuantConfig</span><span class="p">,</span>
            <span class="n">mapping</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Mapping</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>  <span class="c1"># some args are omitted here</span>
        <span class="n">use_modelopt_quantization</span> <span class="o">=</span> <span class="o">...</span>  <span class="c1"># determine whether to use Modelopt or the native implementation</span>
        <span class="k">if</span> <span class="n">use_modelopt_quantization</span><span class="p">:</span>
            <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">quantize</span><span class="p">(</span><span class="n">hf_model_dir</span><span class="p">,</span>
                             <span class="n">output_dir</span><span class="p">,</span>
                             <span class="n">quant_config</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="c1"># Handle TensorRT-LLM native model-specific quantization</span>
            <span class="c1"># or raise an exception if not supported</span>
            <span class="o">...</span>
</pre></div>
</div>
<p>The <code class="docutils literal notranslate"><span class="pre">quantize</span></code> API is designed to use multi-GPU resources internally. For example, LLaMA 70B in BF16 takes 140 GB of memory, and FP8 quantization needs another 70 GB, so at least 210 GB in total is required; 4 A100 (or H100) GPUs are needed to quantize the LLaMA 70B model. If you want to call the <code class="docutils literal notranslate"><span class="pre">quantize</span></code> API inside an MPI program, be cautious and ensure it is called only by rank 0.</p>
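<p>The memory estimate above follows from simple per-parameter arithmetic; the only assumptions are the parameter count and the byte widths of BF16 (2 bytes) and FP8 (1 byte):</p>

```python
# Back-of-envelope memory math for quantizing LLaMA 70B, as in the text above.
num_params = 70e9                 # ~70 billion parameters

bf16_gb = num_params * 2 / 1e9    # BF16 weights: 2 bytes per parameter -> 140 GB
fp8_gb = num_params * 1 / 1e9     # FP8 copy: 1 byte per parameter -> 70 GB
total_gb = bf16_gb + fp8_gb       # peak memory during quantization

print(total_gb)  # 210.0
```

<p>Dividing 210 GB across 80 GB GPUs explains why several A100 or H100 devices are required for this step.</p>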
<p>Usage of the <code class="docutils literal notranslate"><span class="pre">quantize</span></code> API in an MPI program looks like this; only rank 0 calls it. In a non-MPI program, the <code class="docutils literal notranslate"><span class="pre">if</span> <span class="pre">rank</span> <span class="pre">==</span> <span class="pre">0</span></code> check and the <code class="docutils literal notranslate"><span class="pre">mpi_barrier()</span></code> call are not needed.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">quant_config</span> <span class="o">=</span> <span class="n">QuantConfig</span><span class="p">()</span>
<span class="n">quant_config</span><span class="o">.</span><span class="n">quant_algo</span> <span class="o">=</span> <span class="n">quant_mode</span><span class="o">.</span><span class="n">W4A16_AWQ</span>
<span class="n">mapping</span> <span class="o">=</span> <span class="n">Mapping</span><span class="p">(</span><span class="n">world_size</span><span class="o">=</span><span class="n">tp_size</span><span class="p">,</span> <span class="n">tp_size</span><span class="o">=</span><span class="n">tp_size</span><span class="p">)</span>
<span class="k">if</span> <span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
    <span class="n">LLaMAForCausalLM</span><span class="o">.</span><span class="n">quantize</span><span class="p">(</span><span class="n">hf_model_dir</span><span class="p">,</span>
                              <span class="n">checkpoint_dir</span><span class="p">,</span>
                              <span class="n">quant_config</span><span class="o">=</span><span class="n">quant_config</span><span class="p">)</span>
<span class="n">mpi_barrier</span><span class="p">()</span>  <span class="c1"># wait for rank 0 to finish the quantization</span>
<span class="n">llama</span> <span class="o">=</span> <span class="n">LLaMAForCausalLM</span><span class="o">.</span><span class="n">from_checkpoint</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">rank</span><span class="p">)</span>
<span class="n">engine</span> <span class="o">=</span> <span class="n">build</span><span class="p">(</span><span class="n">llama</span><span class="p">,</span> <span class="n">build_config</span><span class="p">)</span>
<span class="n">engine</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">)</span>
</pre></div>
</div>
<p>The <code class="docutils literal notranslate"><span class="pre">examples/quantization/quantize.py</span></code> script is kept for backward compatibility.</p>
</section>
<section id="build-apis">
<h2>Build APIs<a class="headerlink" href="#build-apis" title="Link to this heading"></a></h2>
<p>The <code class="docutils literal notranslate"><span class="pre">tensorrt_llm.build</span></code> API builds a TensorRT-LLM model object into a TensorRT-LLM engine. It replaces the older flow of creating a builder, creating a network object, tracing the model into the network, and building TensorRT engines.
Usage of this API looks like this:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">llama</span> <span class="o">=</span> <span class="o">...</span> <span class="c1"># create LLaMAForCausalLM object</span>
<span class="n">build_config</span> <span class="o">=</span> <span class="n">BuildConfig</span><span class="p">(</span><span class="n">max_batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">engine</span> <span class="o">=</span> <span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">llama</span><span class="p">,</span> <span class="n">build_config</span><span class="p">)</span>
<span class="n">engine</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">)</span>
</pre></div>
</div>
<p>The LLaMA model object can be created by any of the methods described in the <a class="reference internal" href="#conversion-apis"><span class="std std-ref">Conversion APIs</span></a> or <a class="reference internal" href="#quantization-apis"><span class="std std-ref">Quantization APIs</span></a> sections.</p>
<p>The <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code> CLI tool is a thin wrapper around this <code class="docutils literal notranslate"><span class="pre">tensorrt_llm.build</span></code> API. The flags of the CLI tool are kept close to the fields of the <code class="docutils literal notranslate"><span class="pre">BuildConfig</span></code> class.</p>
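<p>As a hedged illustration of that correspondence, the sketch below shows how dataclass-style configuration fields can map one-to-one onto CLI flags. The field names used here (<code class="docutils literal notranslate"><span class="pre">max_batch_size</span></code>, <code class="docutils literal notranslate"><span class="pre">max_input_len</span></code>, <code class="docutils literal notranslate"><span class="pre">max_seq_len</span></code>) are assumptions for illustration; consult the <code class="docutils literal notranslate"><span class="pre">BuildConfig</span></code> class in your installed version for the authoritative field list.</p>

```python
# Illustrative sketch only: a stand-in dataclass showing how BuildConfig-style
# fields map to trtllm-build CLI flags. The field names are assumptions; check
# the real BuildConfig in your installed tensorrt_llm for the actual fields.
from dataclasses import dataclass, fields


@dataclass
class BuildConfigSketch:
    max_batch_size: int = 1
    max_input_len: int = 1024
    max_seq_len: int = 2048


def to_cli_flags(cfg: BuildConfigSketch) -> str:
    # Each field becomes a "--<field_name> <value>" flag, mirroring how the
    # trtllm-build wrapper keeps its flags close to BuildConfig fields.
    return " ".join(f"--{f.name} {getattr(cfg, f.name)}" for f in fields(cfg))


cmd = "trtllm-build " + to_cli_flags(BuildConfigSketch(max_batch_size=8))
print(cmd)  # trtllm-build --max_batch_size 8 --max_input_len 1024 --max_seq_len 2048
```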
<p>If a model is saved to disk and built into an engine later, TensorRT-LLM provides a <code class="docutils literal notranslate"><span class="pre">from_checkpoint</span></code> API to deserialize the checkpoint.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1">## TensorRT-LLM code</span>
<span class="k">class</span> <span class="nc">PretrainedModel</span><span class="p">:</span>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">from_checkpoint</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span>
<span class="n">ckpt_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">config</span><span class="p">:</span> <span class="n">PretrainedConfig</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="c1"># Internally load the model weights from a given checkpoint directory</span>
</pre></div>
</div>
<p>Call the <code class="docutils literal notranslate"><span class="pre">from_checkpoint</span></code> API to deserialize the checkpoint into a model object, then call the <code class="docutils literal notranslate"><span class="pre">tensorrt_llm.build</span></code> API to build the engine.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">llama</span> <span class="o">=</span> <span class="n">LLaMAForCausalLM</span><span class="o">.</span><span class="n">from_checkpoint</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">)</span>
<span class="n">engine</span> <span class="o">=</span> <span class="n">build</span><span class="p">(</span><span class="n">llama</span><span class="p">,</span> <span class="n">build_config</span><span class="p">)</span>
<span class="n">engine</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">)</span>
</pre></div>
</div>
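<p>For tensor-parallel models, this pattern is typically repeated once per rank, since each rank loads its own checkpoint shard and produces its own engine. The sketch below stubs out the TensorRT-LLM calls so that only the control flow is shown; in real code the stubs would be <code class="docutils literal notranslate"><span class="pre">LLaMAForCausalLM.from_checkpoint</span></code> and <code class="docutils literal notranslate"><span class="pre">tensorrt_llm.build</span></code>, and the engine file layout is an assumption.</p>

```python
# Control-flow sketch of per-rank engine builds for tensor parallelism.
# The from_checkpoint/build/save calls are stubs standing in for the real
# TensorRT-LLM APIs; this illustrates only the shape of the loop.
def from_checkpoint_stub(ckpt_dir: str, rank: int) -> dict:
    # Stands in for LLaMAForCausalLM.from_checkpoint(ckpt_dir, rank).
    return {"ckpt": ckpt_dir, "rank": rank}


def build_stub(model: dict, build_config: dict) -> dict:
    # Stands in for tensorrt_llm.build(model, build_config).
    return {"rank": model["rank"], **build_config}


def save_stub(engine: dict, engine_dir: str) -> str:
    # Assumption: one engine file per rank under engine_dir.
    return f"{engine_dir}/rank{engine['rank']}.engine"


tp_size = 2  # assumption: 2-way tensor parallelism
saved = []
for rank in range(tp_size):
    model = from_checkpoint_stub("ckpt_dir", rank)
    engine = build_stub(model, {"max_batch_size": 1})
    saved.append(save_stub(engine, "engine_dir"))

print(saved)  # ['engine_dir/rank0.engine', 'engine_dir/rank1.engine']
```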
</section>
<section id="cli-tools">
<h2>CLI Tools<a class="headerlink" href="#cli-tools" title="Link to this heading"></a></h2>
<p>All of the weight conversion, quantization, and build APIs described above have corresponding CLI tools for convenience.</p>
<ul class="simple">
<li><p>Model specific <code class="docutils literal notranslate"><span class="pre">convert_checkpoint.py</span></code> scripts are inside the <code class="docutils literal notranslate"><span class="pre">examples/&lt;model</span> <span class="pre">xxx&gt;/</span></code> folder.</p></li>
<li><p>A unified quantization script, <code class="docutils literal notranslate"><span class="pre">examples/quantization/quantize.py</span></code>, is shared by all <strong>supported</strong> models.</p></li>
<li><p>The <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code> CLI tool builds engines for all models from a TensorRT-LLM checkpoint.</p></li>
</ul>
<p>Refer to the following considerations for the CLI tools:</p>
<ul>
<li><p>Use these scripts and tools from the command line only; do not import the Python functions and classes they define. TensorRT-LLM does not guarantee that the contents of these scripts remain compatible with previous versions, and the options of these tools may change when it is unavoidable.</p></li>
<li><p>The scripts in the examples folder may use TensorRT-LLM internal or unstable APIs, which are not guaranteed to work if the examples version and the installed TensorRT-LLM version do not match. Several GitHub issues have been caused by such version mismatches:</p>
<ul class="simple">
<li><p><a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/issues/1293">https://github.com/NVIDIA/TensorRT-LLM/issues/1293</a></p></li>
<li><p><a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/issues/1252">https://github.com/NVIDIA/TensorRT-LLM/issues/1252</a></p></li>
<li><p><a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/issues/1079">https://github.com/NVIDIA/TensorRT-LLM/issues/1079</a></p></li>
</ul>
<p>You should always install the same TensorRT-LLM version specified in <code class="docutils literal notranslate"><span class="pre">examples/&lt;model</span> <span class="pre">xxx&gt;/requirements.txt</span></code>.</p>
</li>
<li><p>In the future, the per-model conversion scripts may or may not be unified into a single script shared by all models, since different models can have different attributes. However, the TensorRT-LLM team will try to keep the flags for the same feature consistent between different scripts.</p></li>
<li><p>The TensorRT-LLM team encourages using the new low-level conversion/quantization/build APIs instead of these scripts. The conversion APIs will be added model by model, which may span several releases.</p></li>
</ul>
</section>
</section>
</div>
</div>
<footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
<a href="checkpoint.html" class="btn btn-neutral float-left" title="TensorRT-LLM Checkpoint" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
<a href="add-model.html" class="btn btn-neutral float-right" title="Adding a Model" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
</div>
<hr/>
<div role="contentinfo">
<div class="footer">
<p>
Copyright © 2024 NVIDIA Corporation
</p>
<p>
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/" target="_blank" rel="noopener"
data-cms-ai="0">Privacy Policy</a> |
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/" target="_blank" rel="noopener"
data-cms-ai="0">Manage My Privacy</a> |
<a class="Link" href="https://www.nvidia.com/en-us/preferences/start/" target="_blank" rel="noopener"
data-cms-ai="0">Do Not Sell or Share My Data</a> |
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/" target="_blank"
rel="noopener" data-cms-ai="0">Terms of Service</a> |
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/accessibility/" target="_blank" rel="noopener"
data-cms-ai="0">Accessibility</a> |
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/company-policies/" target="_blank"
rel="noopener" data-cms-ai="0">Corporate Policies</a> |
<a class="Link" href="https://www.nvidia.com/en-us/product-security/" target="_blank" rel="noopener"
data-cms-ai="0">Product Security</a> |
<a class="Link" href="https://www.nvidia.com/en-us/contact/" target="_blank" rel="noopener"
data-cms-ai="0">Contact</a>
</p>
</div>
</div>
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>