mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
517 lines
33 KiB
HTML
517 lines
33 KiB
HTML
<!DOCTYPE html>
|
||
<html class="writer-html5" lang="en" data-content_root="./">
|
||
<head>
|
||
<meta charset="utf-8" />
|
||
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||
<title>New Workflow — tensorrt_llm documentation</title>
|
||
<link rel="stylesheet" type="text/css" href="_static/pygments.css?v=80d5e7a1" />
|
||
<link rel="stylesheet" type="text/css" href="_static/css/theme.css?v=19f00094" />
|
||
|
||
|
||
<!--[if lt IE 9]>
|
||
<script src="_static/js/html5shiv.min.js"></script>
|
||
<![endif]-->
|
||
|
||
<script src="_static/jquery.js?v=5d32c60e"></script>
|
||
<script src="_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
|
||
<script src="_static/documentation_options.js?v=5929fcd5"></script>
|
||
<script src="_static/doctools.js?v=888ff710"></script>
|
||
<script src="_static/sphinx_highlight.js?v=dc90522c"></script>
|
||
<script src="_static/js/theme.js"></script>
|
||
<link rel="index" title="Index" href="genindex.html" />
|
||
<link rel="search" title="Search" href="search.html" />
|
||
<link rel="next" title="Run gpt-2b + LoRA using GptManager / cpp runtime" href="lora.html" />
|
||
<link rel="prev" title="Memory Usage of TensorRT-LLM" href="memory.html" />
|
||
</head>
|
||
|
||
<body class="wy-body-for-nav">
|
||
<div class="wy-grid-for-nav">
|
||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||
<div class="wy-side-scroll">
|
||
<div class="wy-side-nav-search" >
|
||
|
||
|
||
|
||
<a href="index.html" class="icon icon-home">
|
||
tensorrt_llm
|
||
</a>
|
||
<div role="search">
|
||
<form id="rtd-search-form" class="wy-form" action="search.html" method="get">
|
||
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
|
||
<input type="hidden" name="check_keywords" value="yes" />
|
||
<input type="hidden" name="area" value="default" />
|
||
</form>
|
||
</div>
|
||
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
|
||
<p class="caption" role="heading"><span class="caption-text">Contents:</span></p>
|
||
<ul class="current">
|
||
<li class="toctree-l1"><a class="reference internal" href="architecture.html">TensorRT-LLM Architecture</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="gpt_runtime.html">C++ GPT Runtime</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="batch_manager.html">The Batch Manager in TensorRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="inference_request.html">Inference Request</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="gpt_attention.html">Multi-head, Multi-query and Group-query Attention</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="precision.html">Numerical Precision</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="build_from_source.html">Build from Source</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="performance.html">Performance of TensorRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="2023-05-19-how-to-debug.html">How to debug</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="2023-05-17-how-to-add-a-new-model.html">How to add a new model</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="graph-rewriting.html">Graph Rewriting Module</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="memory.html">Memory Usage of TensorRT-LLM</a></li>
|
||
<li class="toctree-l1 current"><a class="current reference internal" href="#">New Workflow</a><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="#overview">Overview</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="#prepare-the-tensorrt-llm-checkpoint">Prepare the TensorRT-LLM Checkpoint</a><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="#config">Config</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#rank-weights">Rank Weights</a><ul>
|
||
<li class="toctree-l4"><a class="reference internal" href="#attention-weights">Attention Weights</a></li>
|
||
<li class="toctree-l4"><a class="reference internal" href="#mlp-weights">MLP Weights</a></li>
|
||
<li class="toctree-l4"><a class="reference internal" href="#layernorm-weights">LayerNorm Weights</a></li>
|
||
<li class="toctree-l4"><a class="reference internal" href="#kv-cache-quantization-scaling-factors">KV Cache Quantization Scaling Factors</a></li>
|
||
<li class="toctree-l4"><a class="reference internal" href="#fp8-quantization-scaling-factors">FP8 Quantization Scaling Factors</a></li>
|
||
<li class="toctree-l4"><a class="reference internal" href="#awq-quantization-scaling-factors">AWQ Quantization Scaling Factors</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#example">Example</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toctree-l2"><a class="reference internal" href="#build-checkpoint-into-tensorrt-engine">Build Checkpoint into TensorRT Engine</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="#make-evaluation">Make Evaluation</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toctree-l1"><a class="reference internal" href="lora.html">Run gpt-2b + LoRA using GptManager / cpp runtime</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="perf_best_practices.html">Best Practices for Tuning the Performance of TensorRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="performance_analysis.html">Performance Analysis of TensorRT-LLM</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="python-api/tensorrt_llm.layers.html">Layers</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="python-api/tensorrt_llm.functional.html">Functionals</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="python-api/tensorrt_llm.models.html">Models</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="python-api/tensorrt_llm.plugin.html">Plugin</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="python-api/tensorrt_llm.quantization.html">Quantization</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="python-api/tensorrt_llm.runtime.html">Runtime</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="_cpp_gen/runtime.html">Runtime</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Blogs</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="blogs/Falcon180B-H200.html">Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="blogs/quantization-in-TRT-LLM.html">Speed up inference with SOTA quantization techniques in TRT-LLM</a></li>
|
||
</ul>
|
||
|
||
</div>
|
||
</div>
|
||
</nav>
|
||
|
||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
|
||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||
<a href="index.html">tensorrt_llm</a>
|
||
</nav>
|
||
|
||
<div class="wy-nav-content">
|
||
<div class="rst-content">
|
||
<div role="navigation" aria-label="Page navigation">
|
||
<ul class="wy-breadcrumbs">
|
||
<li><a href="index.html" class="icon icon-home" aria-label="Home"></a></li>
|
||
<li class="breadcrumb-item active">New Workflow</li>
|
||
<li class="wy-breadcrumbs-aside">
|
||
<a href="_sources/new_workflow.md.txt" rel="nofollow"> View page source</a>
|
||
</li>
|
||
</ul>
|
||
<hr/>
|
||
</div>
|
||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||
<div itemprop="articleBody">
|
||
|
||
<section id="new-workflow">
|
||
<h1>New Workflow<a class="headerlink" href="#new-workflow" title="Link to this heading"></a></h1>
|
||
<section id="overview">
|
||
<h2>Overview<a class="headerlink" href="#overview" title="Link to this heading"></a></h2>
|
||
<p>The first versions of TensorRT-LLM were developed on a very aggressive timeline, so those versions put little emphasis
|
||
on defining a unified workflow. Now that TensorRT-LLM has reached some level of feature richness, the development team has
|
||
decided to put more effort into unifying the APIs and workflow of TensorRT-LLM. This document summarises the new workflow
|
||
adopted by TensorRT-LLM at its core.</p>
|
||
<p>There are 3 steps in the new workflow:</p>
|
||
<ol class="arabic simple">
|
||
<li><p>Convert weights from different source frameworks into a TensorRT-LLM checkpoint</p></li>
|
||
<li><p>Build the TensorRT-LLM checkpoint into TensorRT engine(s) with a unified build command</p></li>
|
||
<li><p>Load the engine(s) into the TensorRT-LLM model runner and evaluate them on different evaluation tasks</p></li>
|
||
</ol>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">NeMo</span> <span class="o">-------------</span>
|
||
<span class="o">|</span>
|
||
<span class="n">HuggingFace</span> <span class="o">------</span>
|
||
<span class="o">|</span> <span class="n">convert</span> <span class="n">build</span> <span class="n">load</span>
|
||
<span class="n">AMMO</span> <span class="o">-------------</span> <span class="o">----------></span> <span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span> <span class="n">Checkpoint</span> <span class="o">--------></span> <span class="n">TensorRT</span> <span class="n">Engine</span> <span class="o">------></span> <span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span> <span class="n">ModelRunner</span>
|
||
<span class="o">|</span>
|
||
<span class="n">JAX</span> <span class="o">--------------</span>
|
||
<span class="o">|</span>
|
||
<span class="n">DeepSpeed</span> <span class="o">--------</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="prepare-the-tensorrt-llm-checkpoint">
|
||
<h2>Prepare the TensorRT-LLM Checkpoint<a class="headerlink" href="#prepare-the-tensorrt-llm-checkpoint" title="Link to this heading"></a></h2>
|
||
<p>TensorRT-LLM aims to support models from different sources:</p>
|
||
<ol class="arabic simple">
|
||
<li><p>Trained models from NeMo, DeepSpeed, JAX</p></li>
|
||
<li><p>Quantized models from AMMO</p></li>
|
||
<li><p>Popular models from HuggingFace</p></li>
|
||
</ol>
|
||
<p>TensorRT-LLM defines its own checkpoint format. A checkpoint directory includes:</p>
|
||
<ol class="arabic simple">
|
||
<li><p>One config json file, which contains several model hyper-parameters</p></li>
|
||
<li><p>One or several rank weights files, each file contains a dictionary of tensors (weights).
|
||
The different files will be loaded by different ranks in a multi-GPU (multi-process) scenario</p></li>
|
||
</ol>
|
||
<section id="config">
|
||
<h3>Config<a class="headerlink" href="#config" title="Link to this heading"></a></h3>
|
||
<table class="docutils align-default">
|
||
<thead>
|
||
<tr class="row-odd"><th class="head text-left"><p>Field</p></th>
|
||
<th class="head text-left"><p>Type</p></th>
|
||
<th class="head text-left"><p>Default Value</p></th>
|
||
</tr>
|
||
</thead>
|
||
<tbody>
|
||
<tr class="row-even"><td class="text-left"><p>architecture</p></td>
|
||
<td class="text-left"><p>string</p></td>
|
||
<td class="text-left"><p>mandatory</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td class="text-left"><p>dtype</p></td>
|
||
<td class="text-left"><p>string</p></td>
|
||
<td class="text-left"><p>mandatory</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td class="text-left"><p>logits_dtype</p></td>
|
||
<td class="text-left"><p>string</p></td>
|
||
<td class="text-left"><p>‘float32’</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td class="text-left"><p>vocab_size</p></td>
|
||
<td class="text-left"><p>int</p></td>
|
||
<td class="text-left"><p>mandatory</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td class="text-left"><p>max_position_embeddings</p></td>
|
||
<td class="text-left"><p>int</p></td>
|
||
<td class="text-left"><p>null</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td class="text-left"><p>hidden_size</p></td>
|
||
<td class="text-left"><p>int</p></td>
|
||
<td class="text-left"><p>mandatory</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td class="text-left"><p>num_hidden_layers</p></td>
|
||
<td class="text-left"><p>int</p></td>
|
||
<td class="text-left"><p>mandatory</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td class="text-left"><p>num_attention_heads</p></td>
|
||
<td class="text-left"><p>int</p></td>
|
||
<td class="text-left"><p>mandatory</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td class="text-left"><p>num_key_value_heads</p></td>
|
||
<td class="text-left"><p>int</p></td>
|
||
<td class="text-left"><p>num_attention_heads</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td class="text-left"><p>hidden_act</p></td>
|
||
<td class="text-left"><p>string</p></td>
|
||
<td class="text-left"><p>mandatory</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td class="text-left"><p>intermediate_size</p></td>
|
||
<td class="text-left"><p>int</p></td>
|
||
<td class="text-left"><p>null</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td class="text-left"><p>norm_epsilon</p></td>
|
||
<td class="text-left"><p>float</p></td>
|
||
<td class="text-left"><p>1e-5</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td class="text-left"><p>position_embedding_type</p></td>
|
||
<td class="text-left"><p>string</p></td>
|
||
<td class="text-left"><p>‘learned_absolute’</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td class="text-left"><p>use_prompt_tuning</p></td>
|
||
<td class="text-left"><p>bool</p></td>
|
||
<td class="text-left"><p>false</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td class="text-left"><p>mapping.world_size</p></td>
|
||
<td class="text-left"><p>int</p></td>
|
||
<td class="text-left"><p>1</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td class="text-left"><p>mapping.tp_size</p></td>
|
||
<td class="text-left"><p>int</p></td>
|
||
<td class="text-left"><p>1</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td class="text-left"><p>mapping.pp_size</p></td>
|
||
<td class="text-left"><p>int</p></td>
|
||
<td class="text-left"><p>1</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td class="text-left"><p>quantization.quant_algo</p></td>
|
||
<td class="text-left"><p>string</p></td>
|
||
<td class="text-left"><p>null</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td class="text-left"><p>quantization.kv_cache_quant_algo</p></td>
|
||
<td class="text-left"><p>string</p></td>
|
||
<td class="text-left"><p>null</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td class="text-left"><p>quantization.group_size</p></td>
|
||
<td class="text-left"><p>int</p></td>
|
||
<td class="text-left"><p>64</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td class="text-left"><p>quantization.has_zero_point</p></td>
|
||
<td class="text-left"><p>bool</p></td>
|
||
<td class="text-left"><p>False</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td class="text-left"><p>quantization.pre_quant_scale</p></td>
|
||
<td class="text-left"><p>bool</p></td>
|
||
<td class="text-left"><p>False</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td class="text-left"><p>quantization.exclude_modules</p></td>
|
||
<td class="text-left"><p>list</p></td>
|
||
<td class="text-left"><p>null</p></td>
|
||
</tr>
|
||
</tbody>
|
||
</table>
|
||
<p><code class="docutils literal notranslate"><span class="pre">mapping.world_size</span></code> means that <code class="docutils literal notranslate"><span class="pre">mapping</span></code> is a dictionary containing a <code class="docutils literal notranslate"><span class="pre">world_size</span></code> sub-field.</p>
|
||
<div class="highlight-json notranslate"><div class="highlight"><pre><span></span><span class="p">{</span>
|
||
<span class="w"> </span><span class="nt">"architecture"</span><span class="p">:</span><span class="w"> </span><span class="s2">"OPTForCausalLM"</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"mapping"</span><span class="p">:</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="nt">"world_size"</span><span class="p">:</span><span class="w"> </span><span class="mi">1</span>
|
||
<span class="w"> </span><span class="p">}</span>
|
||
<span class="p">}</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>Supported quantization algorithm list:</p>
|
||
<ul class="simple">
|
||
<li><p>W8A16</p></li>
|
||
<li><p>W4A16</p></li>
|
||
<li><p>W4A16_AWQ</p></li>
|
||
<li><p>W4A8_AWQ</p></li>
|
||
<li><p>W4A16_GPTQ</p></li>
|
||
<li><p>FP8</p></li>
|
||
<li><p>W8A8_SQ_PER_CHANNEL</p></li>
|
||
</ul>
|
||
<p>Supported KV cache quantization algorithm list:</p>
|
||
<ul class="simple">
|
||
<li><p>FP8</p></li>
|
||
<li><p>INT8</p></li>
|
||
</ul>
|
||
<p>The config is extensible: a model can add its own specific config fields.
|
||
For example, OPT model has a <code class="docutils literal notranslate"><span class="pre">do_layer_norm_before</span></code> field.</p>
|
||
<p>Here is the list of model-specific config fields:</p>
|
||
<table class="docutils align-default">
|
||
<thead>
|
||
<tr class="row-odd"><th class="head text-left"><p>Field</p></th>
|
||
<th class="head text-left"><p>Type</p></th>
|
||
<th class="head text-left"><p>Default Value</p></th>
|
||
</tr>
|
||
</thead>
|
||
<tbody>
|
||
<tr class="row-even"><td class="text-left"><p>OPT</p></td>
|
||
<td class="text-left"><p></p></td>
|
||
<td class="text-left"><p></p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td class="text-left"><p>do_layer_norm_before</p></td>
|
||
<td class="text-left"><p>bool</p></td>
|
||
<td class="text-left"><p>False</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td class="text-left"><p></p></td>
|
||
<td class="text-left"><p></p></td>
|
||
<td class="text-left"><p></p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td class="text-left"><p>Falcon</p></td>
|
||
<td class="text-left"><p></p></td>
|
||
<td class="text-left"><p></p></td>
|
||
</tr>
|
||
<tr class="row-even"><td class="text-left"><p>bias</p></td>
|
||
<td class="text-left"><p>bool</p></td>
|
||
<td class="text-left"><p>True</p></td>
|
||
</tr>
|
||
<tr class="row-odd"><td class="text-left"><p>new_decoder_architecture</p></td>
|
||
<td class="text-left"><p>bool</p></td>
|
||
<td class="text-left"><p>False</p></td>
|
||
</tr>
|
||
<tr class="row-even"><td class="text-left"><p>parallel_attention</p></td>
|
||
<td class="text-left"><p>bool</p></td>
|
||
<td class="text-left"><p>False</p></td>
|
||
</tr>
|
||
</tbody>
|
||
</table>
|
||
</section>
|
||
<section id="rank-weights">
|
||
<h3>Rank Weights<a class="headerlink" href="#rank-weights" title="Link to this heading"></a></h3>
|
||
<p>As in PyTorch, the tensor (weight) name is a string containing hierarchical information,
|
||
which is uniquely mapped to a certain parameter of a TensorRT-LLM model.</p>
|
||
<p>For example, each transformer layer of the OPT model contains an <code class="docutils literal notranslate"><span class="pre">Attention</span></code> layer, an <code class="docutils literal notranslate"><span class="pre">MLP</span></code> layer and two <code class="docutils literal notranslate"><span class="pre">LayerNorm</span></code> layers.</p>
|
||
<section id="attention-weights">
|
||
<h4>Attention Weights<a class="headerlink" href="#attention-weights" title="Link to this heading"></a></h4>
|
||
<p>The <code class="docutils literal notranslate"><span class="pre">Attention</span></code> layer contains two <code class="docutils literal notranslate"><span class="pre">Linear</span></code> layers, qkv and dense; each <code class="docutils literal notranslate"><span class="pre">Linear</span></code> layer contains one weight and one bias.
|
||
So, there are four tensors (weights) in total, whose names are:</p>
|
||
<ul class="simple">
|
||
<li><p>“transformer.layers.0.attention.qkv.weight”</p></li>
|
||
<li><p>“transformer.layers.0.attention.qkv.bias”</p></li>
|
||
<li><p>“transformer.layers.0.attention.dense.weight”</p></li>
|
||
<li><p>“transformer.layers.0.attention.dense.bias”</p></li>
|
||
</ul>
|
||
<p>where <code class="docutils literal notranslate"><span class="pre">transformer.layers.0.attention</span></code> is the prefix name, indicating that the weights/biases are in the attention module of the 0-th transformer layer.</p>
|
||
</section>
|
||
<section id="mlp-weights">
|
||
<h4>MLP Weights<a class="headerlink" href="#mlp-weights" title="Link to this heading"></a></h4>
|
||
<p>The <code class="docutils literal notranslate"><span class="pre">MLP</span></code> layer also contains two <code class="docutils literal notranslate"><span class="pre">Linear</span></code> layers, fc and proj; each <code class="docutils literal notranslate"><span class="pre">Linear</span></code> layer contains one weight and one bias.
|
||
So, there are four tensors (weights) in total, whose names are:</p>
|
||
<ul class="simple">
|
||
<li><p>“transformer.layers.0.mlp.fc.weight”</p></li>
|
||
<li><p>“transformer.layers.0.mlp.fc.bias”</p></li>
|
||
<li><p>“transformer.layers.0.mlp.proj.weight”</p></li>
|
||
<li><p>“transformer.layers.0.mlp.proj.bias”</p></li>
|
||
</ul>
|
||
<p>where <code class="docutils literal notranslate"><span class="pre">transformer.layers.0.mlp</span></code> is the prefix name, indicating that the weights/biases are in the mlp module of the 0-th transformer layer.</p>
|
||
</section>
|
||
<section id="layernorm-weights">
|
||
<h4>LayerNorm Weights<a class="headerlink" href="#layernorm-weights" title="Link to this heading"></a></h4>
|
||
<p>Each of the two <code class="docutils literal notranslate"><span class="pre">LayerNorm</span></code> layers, namely input_layernorm and post_layernorm, contains one weight and one bias.
|
||
So, there are four tensors (weights) in total, whose names are:</p>
|
||
<ul class="simple">
|
||
<li><p>“transformer.layers.0.input_layernorm.weight”</p></li>
|
||
<li><p>“transformer.layers.0.input_layernorm.bias”</p></li>
|
||
<li><p>“transformer.layers.0.post_layernorm.weight”</p></li>
|
||
<li><p>“transformer.layers.0.post_layernorm.bias”</p></li>
|
||
</ul>
|
||
<p>where <code class="docutils literal notranslate"><span class="pre">transformer.layers.0.input_layernorm</span></code> and <code class="docutils literal notranslate"><span class="pre">transformer.layers.0.post_layernorm</span></code> are prefix names for the two layernorm modules.</p>
|
||
</section>
|
||
<section id="kv-cache-quantization-scaling-factors">
|
||
<h4>KV Cache Quantization Scaling Factors<a class="headerlink" href="#kv-cache-quantization-scaling-factors" title="Link to this heading"></a></h4>
|
||
<p>Note that if we quantize the model, there will be different tensors (depending on the quantization method applied).
|
||
For example, if we quantize the KV cache, the <code class="docutils literal notranslate"><span class="pre">Attention</span></code> layer will have this extra scaling factor:</p>
|
||
<ul class="simple">
|
||
<li><p>“transformer.layers.0.attention.kv_cache_scaling_factor”</p></li>
|
||
</ul>
|
||
</section>
|
||
<section id="fp8-quantization-scaling-factors">
|
||
<h4>FP8 Quantization Scaling Factors<a class="headerlink" href="#fp8-quantization-scaling-factors" title="Link to this heading"></a></h4>
|
||
<p>For example, here are the FP8 scaling factors of the attention.qkv linear layer:</p>
|
||
<ul class="simple">
|
||
<li><p>“transformer.layers.0.attention.qkv.activation_scaling_factor”</p></li>
|
||
<li><p>“transformer.layers.0.attention.qkv.weights_scaling_factor”</p></li>
|
||
</ul>
|
||
</section>
|
||
<section id="awq-quantization-scaling-factors">
|
||
<h4>AWQ Quantization Scaling Factors<a class="headerlink" href="#awq-quantization-scaling-factors" title="Link to this heading"></a></h4>
|
||
<p>For example, here are the AWQ scaling factors of the mlp.fc linear layer:</p>
|
||
<ul class="simple">
|
||
<li><p>“transformer.layers.0.mlp.fc.weights_scaling_factor”</p></li>
|
||
<li><p>“transformer.layers.0.mlp.fc.prequant_scaling_factor”</p></li>
|
||
</ul>
|
||
<p><strong>Note</strong>: The linear weights in a TensorRT-LLM checkpoint always follow the (out_feature, in_feature) shape,
|
||
whereas some quantized linear layers implemented by TensorRT-LLM plugins may use the (in_feature, out_feature) shape.
|
||
The <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code> command adds a transpose operation to post-process such weights.</p>
|
||
</section>
|
||
</section>
|
||
<section id="example">
|
||
<h3>Example<a class="headerlink" href="#example" title="Link to this heading"></a></h3>
|
||
<p>Let’s take OPT as an example and deploy the model with a tensor parallelism of 2:</p>
|
||
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">cd</span><span class="w"> </span>examples/opt
|
||
python3<span class="w"> </span>convert_checkpoint.py<span class="w"> </span>--model_dir<span class="w"> </span>./opt-125m<span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>--dtype<span class="w"> </span>float16<span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>--world_size<span class="w"> </span><span class="m">2</span><span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>--output_dir<span class="w"> </span>./opt/125M/trt_ckpt/fp16/2-gpu/
|
||
</pre></div>
|
||
</div>
|
||
<p>Here is the checkpoint directory:</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="o">./</span><span class="n">opt</span><span class="o">/</span><span class="mi">125</span><span class="n">M</span><span class="o">/</span><span class="n">trt_ckpt</span><span class="o">/</span><span class="n">fp16</span><span class="o">/</span><span class="mi">2</span><span class="o">-</span><span class="n">gpu</span><span class="o">/</span>
|
||
<span class="n">config</span><span class="o">.</span><span class="n">json</span>
|
||
<span class="n">rank0</span><span class="o">.</span><span class="n">safetensors</span>
|
||
<span class="n">rank1</span><span class="o">.</span><span class="n">safetensors</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>Here is the <code class="docutils literal notranslate"><span class="pre">config.json</span></code>:</p>
|
||
<div class="highlight-json notranslate"><div class="highlight"><pre><span></span><span class="p">{</span>
|
||
<span class="w"> </span><span class="nt">"architecture"</span><span class="p">:</span><span class="w"> </span><span class="s2">"OPTForCausalLM"</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"dtype"</span><span class="p">:</span><span class="w"> </span><span class="s2">"float16"</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"logits_dtype"</span><span class="p">:</span><span class="w"> </span><span class="s2">"float32"</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"num_hidden_layers"</span><span class="p">:</span><span class="w"> </span><span class="mi">12</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"num_attention_heads"</span><span class="p">:</span><span class="w"> </span><span class="mi">12</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"hidden_size"</span><span class="p">:</span><span class="w"> </span><span class="mi">768</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"vocab_size"</span><span class="p">:</span><span class="w"> </span><span class="mi">50272</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"position_embedding_type"</span><span class="p">:</span><span class="w"> </span><span class="s2">"learned_absolute"</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"max_position_embeddings"</span><span class="p">:</span><span class="w"> </span><span class="mi">2048</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"hidden_act"</span><span class="p">:</span><span class="w"> </span><span class="s2">"relu"</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"quantization"</span><span class="p">:</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="nt">"use_weight_only"</span><span class="p">:</span><span class="w"> </span><span class="kc">false</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"weight_only_precision"</span><span class="p">:</span><span class="w"> </span><span class="s2">"int8"</span>
|
||
<span class="w"> </span><span class="p">},</span>
|
||
<span class="w"> </span><span class="nt">"mapping"</span><span class="p">:</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="nt">"world_size"</span><span class="p">:</span><span class="w"> </span><span class="mi">2</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"tp_size"</span><span class="p">:</span><span class="w"> </span><span class="mi">2</span>
|
||
<span class="w"> </span><span class="p">},</span>
|
||
<span class="w"> </span><span class="nt">"use_parallel_embedding"</span><span class="p">:</span><span class="w"> </span><span class="kc">false</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"embedding_sharding_dim"</span><span class="p">:</span><span class="w"> </span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"share_embedding_table"</span><span class="p">:</span><span class="w"> </span><span class="kc">false</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"do_layer_norm_before"</span><span class="p">:</span><span class="w"> </span><span class="kc">true</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="nt">"use_prompt_tuning"</span><span class="p">:</span><span class="w"> </span><span class="kc">false</span>
|
||
<span class="p">}</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
</section>
|
||
<section id="build-checkpoint-into-tensorrt-engine">
|
||
<h2>Build Checkpoint into TensorRT Engine<a class="headerlink" href="#build-checkpoint-into-tensorrt-engine" title="Link to this heading"></a></h2>
|
||
<p>TensorRT-LLM provides a unified build command: <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code>. Before using it,
|
||
you may need to add it to the <code class="docutils literal notranslate"><span class="pre">PATH</span></code>:</p>
|
||
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">export</span><span class="w"> </span><span class="nv">PATH</span><span class="o">=</span>/usr/local/bin:<span class="nv">$PATH</span>
|
||
|
||
trtllm-build<span class="w"> </span>--checkpoint_dir<span class="w"> </span>./opt/125M/trt_ckpt/fp16/2-gpu/<span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>--gemm_plugin<span class="w"> </span>float16<span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>--max_batch_size<span class="w"> </span><span class="m">8</span><span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>--max_input_len<span class="w"> </span><span class="m">924</span><span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>--max_output_len<span class="w"> </span><span class="m">100</span><span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>--output_dir<span class="w"> </span>./opt/125M/trt_engines/fp16/2-gpu/
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="make-evaluation">
|
||
<h2>Make Evaluation<a class="headerlink" href="#make-evaluation" title="Link to this heading"></a></h2>
|
||
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>mpirun<span class="w"> </span>-n<span class="w"> </span><span class="m">2</span><span class="w"> </span>--allow-run-as-root<span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>python3<span class="w"> </span>../summarize.py<span class="w"> </span>--engine_dir<span class="w"> </span>./opt/125M/trt_engines/fp16/2-gpu/<span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>--batch_size<span class="w"> </span><span class="m">1</span><span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>--test_trt_llm<span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>--hf_model_dir<span class="w"> </span>opt-125m<span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>--data_type<span class="w"> </span>fp16<span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>--check_accuracy<span class="w"> </span><span class="se">\</span>
|
||
<span class="w"> </span>--tensorrt_llm_rouge1_threshold<span class="o">=</span><span class="m">14</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
</section>
|
||
|
||
|
||
</div>
|
||
</div>
|
||
<footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
|
||
<a href="memory.html" class="btn btn-neutral float-left" title="Memory Usage of TensorRT-LLM" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||
<a href="lora.html" class="btn btn-neutral float-right" title="Run gpt-2b + LoRA using GptManager / cpp runtime" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||
</div>
|
||
|
||
<hr/>
|
||
|
||
<div role="contentinfo">
|
||
<p>© Copyright 2023, NVidia.</p>
|
||
</div>
|
||
|
||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||
|
||
|
||
</footer>
|
||
</div>
|
||
</div>
|
||
</section>
|
||
</div>
|
||
<script>
|
||
jQuery(function () {
|
||
SphinxRtdTheme.Navigation.enable(true);
|
||
});
|
||
</script>
|
||
|
||
</body>
|
||
</html> |