<!DOCTYPE html>
<html class="writer-html5" lang="en" data-content_root="./">
<head>
<meta charset="utf-8" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Run gpt-2b + LoRA using GptManager / cpp runtime &mdash; tensorrt_llm documentation</title>
<link rel="stylesheet" type="text/css" href="_static/pygments.css?v=80d5e7a1" />
<link rel="stylesheet" type="text/css" href="_static/css/theme.css?v=19f00094" />
<!--[if lt IE 9]>
<script src="_static/js/html5shiv.min.js"></script>
<![endif]-->
<script src="_static/jquery.js?v=5d32c60e"></script>
<script src="_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="_static/documentation_options.js?v=5929fcd5"></script>
<script src="_static/doctools.js?v=888ff710"></script>
<script src="_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="_static/js/theme.js"></script>
<link rel="index" title="Index" href="genindex.html" />
<link rel="search" title="Search" href="search.html" />
<link rel="next" title="Best Practices for Tuning the Performance of TensorRT-LLM" href="perf_best_practices.html" />
<link rel="prev" title="New Workflow" href="new_workflow.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="index.html" class="icon icon-home">
tensorrt_llm
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Contents:</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="architecture.html">TensorRT-LLM Architecture</a></li>
<li class="toctree-l1"><a class="reference internal" href="gpt_runtime.html">C++ GPT Runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="batch_manager.html">The Batch Manager in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="inference_request.html">Inference Request</a></li>
<li class="toctree-l1"><a class="reference internal" href="gpt_attention.html">Multi-head, Multi-query and Group-query Attention</a></li>
<li class="toctree-l1"><a class="reference internal" href="precision.html">Numerical Precision</a></li>
<li class="toctree-l1"><a class="reference internal" href="build_from_source.html">Build from Source</a></li>
<li class="toctree-l1"><a class="reference internal" href="performance.html">Performance of TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="2023-05-19-how-to-debug.html">How to debug</a></li>
<li class="toctree-l1"><a class="reference internal" href="2023-05-17-how-to-add-a-new-model.html">How to add a new model</a></li>
<li class="toctree-l1"><a class="reference internal" href="graph-rewriting.html">Graph Rewriting Module</a></li>
<li class="toctree-l1"><a class="reference internal" href="memory.html">Memory Usage of TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="new_workflow.html">New Workflow</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Run gpt-2b + LoRA using GptManager / cpp runtime</a><ul>
<li class="toctree-l2"><a class="reference internal" href="#lora-tensor-format-details">LoRA tensor format details</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#example-lora-tensors">Example LoRA tensors</a></li>
<li class="toctree-l3"><a class="reference internal" href="#lora-module-id-mapping">LoRA Module id mapping</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="perf_best_practices.html">Best Practices for Tuning the Performance of TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="performance_analysis.html">Performance Analysis of TensorRT-LLM</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="python-api/tensorrt_llm.layers.html">Layers</a></li>
<li class="toctree-l1"><a class="reference internal" href="python-api/tensorrt_llm.functional.html">Functionals</a></li>
<li class="toctree-l1"><a class="reference internal" href="python-api/tensorrt_llm.models.html">Models</a></li>
<li class="toctree-l1"><a class="reference internal" href="python-api/tensorrt_llm.plugin.html">Plugin</a></li>
<li class="toctree-l1"><a class="reference internal" href="python-api/tensorrt_llm.quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="python-api/tensorrt_llm.runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="_cpp_gen/runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Blogs</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
<li class="toctree-l1"><a class="reference internal" href="blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="blogs/Falcon180B-H200.html">Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100</a></li>
<li class="toctree-l1"><a class="reference internal" href="blogs/quantization-in-TRT-LLM.html">Speed up inference with SOTA quantization techniques in TRT-LLM</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="index.html">tensorrt_llm</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="index.html" class="icon icon-home" aria-label="Home"></a></li>
<li class="breadcrumb-item active">Run gpt-2b + LoRA using GptManager / cpp runtime</li>
<li class="wy-breadcrumbs-aside">
<a href="_sources/lora.md.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<section id="run-gpt-2b-lora-using-gptmanager-cpp-runtime">
<h1>Run gpt-2b + LoRA using GptManager / cpp runtime<a class="headerlink" href="#run-gpt-2b-lora-using-gptmanager-cpp-runtime" title="Link to this heading"></a></h1>
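<p>First, build an engine with LoRA and in-flight batching enabled. The commands below download two Hugging Face LoRA checkpoints for LLaMA-7B; both are converted for the runtime in the next step.</p>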
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>git-lfs<span class="w"> </span>clone<span class="w"> </span>https://huggingface.co/qychen/luotuo-lora-7b-0.1
git-lfs<span class="w"> </span>clone<span class="w"> </span>https://huggingface.co/kunishou/Japanese-Alpaca-LoRA-7b-v0
<span class="nv">BASE_MODEL</span><span class="o">=</span>llama-7b-hf
python<span class="w"> </span>examples/llama/convert_checkpoint.py<span class="w"> </span>--model_dir<span class="w"> </span><span class="si">${</span><span class="nv">BASE_MODEL</span><span class="si">}</span><span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--output_dir<span class="w"> </span>/tmp/llama_7b_with_lora_qkv/trt_ckpt/fp16/1-gpu/<span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--dtype<span class="w"> </span>float16<span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--hf_lora_dir<span class="w"> </span>Japanese-Alpaca-LoRA-7b-v0<span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--max_lora_rank<span class="w"> </span><span class="m">8</span><span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--lora_target_modules<span class="w"> </span><span class="s2">&quot;attn_q&quot;</span><span class="w"> </span><span class="s2">&quot;attn_k&quot;</span><span class="w"> </span><span class="s2">&quot;attn_v&quot;</span>
trtllm-build<span class="w"> </span>--checkpoint_dir<span class="w"> </span>/tmp/llama_7b_with_lora_qkv/trt_ckpt/fp16/1-gpu/<span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--output_dir<span class="w"> </span>/tmp/llama_7b_with_lora_qkv/trt_engines/fp16/1-gpu/<span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--remove_input_padding<span class="w"> </span><span class="nb">enable</span><span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--gpt_attention_plugin<span class="w"> </span>float16<span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--context_fmha<span class="w"> </span><span class="nb">enable</span><span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--paged_kv_cache<span class="w"> </span><span class="nb">enable</span><span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--gemm_plugin<span class="w"> </span>float16<span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--lora_plugin<span class="w"> </span>float16<span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--max_batch_size<span class="w"> </span><span class="m">128</span><span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--max_input_len<span class="w"> </span><span class="m">512</span><span class="w"> </span><span class="se">\</span>
<span class="w"> </span>--max_output_len<span class="w"> </span><span class="m">50</span>
</pre></div>
</div>
<p>To pass LoRA weights to the C++ runtime, they must first be converted to the tensor format described below.
The script below converts a Hugging Face LoRA model into the corresponding NumPy tensors.</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>python3<span class="w"> </span>tensorrt_llm/examples/hf_lora_convert.py<span class="w"> </span>-i<span class="w"> </span>Japanese-Alpaca-LoRA-7b-v0<span class="w"> </span>-o<span class="w"> </span>Japanese-Alpaca-LoRA-7b-v0-weights<span class="w"> </span>--storage-type<span class="w"> </span>float16
python3<span class="w"> </span>tensorrt_llm/examples/hf_lora_convert.py<span class="w"> </span>-i<span class="w"> </span>luotuo-lora-7b-0.1<span class="w"> </span>-o<span class="w"> </span>luotuo-lora-7b-0.1-weights<span class="w"> </span>--storage-type<span class="w"> </span>float16
</pre></div>
</div>
<p>See the tensorrtllm_backend <a class="reference external" href="https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/inflight_batcher_llm/README.md">docs</a> for a multi-LoRA example using Triton.</p>
<section id="lora-tensor-format-details">
<h2>LoRA tensor format details<a class="headerlink" href="#lora-tensor-format-details" title="Link to this heading"></a></h2>
<p>To run inference with LoRA weights using GptManager, each InferenceRequest must provide the LoraWeights (lora_weights) and LoraConfig (lora_config) parameters.</p>
<p><code class="docutils literal notranslate"><span class="pre">LoraWeights</span></code> contains the weights for all the LoRAs. Currently, it must include the weights for all TP and PP ranks.
The weights tensor has the shape <code class="docutils literal notranslate"><span class="pre">[</span> <span class="pre">num_lora_modules_layers,</span> <span class="pre">D</span> <span class="pre">x</span> <span class="pre">Hi</span> <span class="pre">+</span> <span class="pre">Ho</span> <span class="pre">x</span> <span class="pre">D</span> <span class="pre">]</span></code>. The last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer.
Each of the in / out tensors is first flattened, and the two are then concatenated in that order.
Because rows can have different lengths (for example, when adapter sizes differ), shorter rows are padded at the end so that all rows share the same length, as in the example tensors.
The first dimension (of size <code class="docutils literal notranslate"><span class="pre">num_lora_modules_layers</span></code>) has one entry per module-layer pair (i.e. there is one entry for attn_q in layer 1 and another for attn_k in layer 1).</p>
<p><code class="docutils literal notranslate"><span class="pre">D=adapter_size</span> <span class="pre">(i.e.</span> <span class="pre">R</span> <span class="pre">value),</span> <span class="pre">Hi=hidden_size_in,</span> <span class="pre">Ho=hidden_size_out.</span></code></p>
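<p>The row layout above can be sketched in NumPy as follows. This is a minimal illustration, not the converter's implementation: it assumes the in weights are a <code>[D, Hi]</code> array (<code>A</code> in the LoRA paper) and the out weights a <code>[Ho, D]</code> array (<code>B</code>), both flattened in row-major (C) order. The helper names <code>pack_row</code>, <code>unpack_row</code> and <code>apply_lora</code> are hypothetical, and the sketch omits the <code>alpha / r</code> scaling from the LoRA paper because this page does not say where, or whether, it is applied.</p>

```python
import numpy as np


def pack_row(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Flatten in weights A [D, Hi] and out weights B [Ho, D] into one row.

    Hypothetical helper: assumes row-major (C order) flattening, A then B.
    """
    d = a.shape[0]
    if b.shape[1] != d:
        raise ValueError("A and B must share the adapter size D")
    return np.concatenate([a.reshape(-1), b.reshape(-1)])  # length D*Hi + Ho*D


def unpack_row(row: np.ndarray, d: int, hi: int, ho: int):
    """Inverse of pack_row; any trailing padding past D*Hi + Ho*D is ignored."""
    a = row[: d * hi].reshape(d, hi)
    b = row[d * hi : d * hi + ho * d].reshape(ho, d)
    return a, b


def apply_lora(x: np.ndarray, w0: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """One module's forward pass, y = W0 x + B (A x), without alpha / r scaling."""
    return w0 @ x + b @ (a @ x)


# Usage: a rank-2 adapter (D = 2) on a module with Hi = Ho = 4.
rng = np.random.default_rng(0)
A = rng.standard_normal((2, 4)).astype(np.float16)  # in weights
B = rng.standard_normal((4, 2)).astype(np.float16)  # out weights
row = pack_row(A, B)
print(row.shape)  # (16,) = 2*4 + 4*2
A2, B2 = unpack_row(row, d=2, hi=4, ho=4)
W0 = np.eye(4, dtype=np.float16)
y = apply_lora(np.ones(4, dtype=np.float16), W0, A2, B2)
```

<p>Because <code>apply_lora</code> computes <code>B (A x)</code> rather than forming <code>B A</code>, the cost stays proportional to the adapter size <code>D</code>, which is the point of a low-rank adapter.</p>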
<p><code class="docutils literal notranslate"><span class="pre">LoraConfig</span></code> is a configuration tensor that identifies the module id, layer index, and adapter size of each row of <code class="docutils literal notranslate"><span class="pre">LoraWeights</span></code>.
It has the shape <code class="docutils literal notranslate"><span class="pre">[num_lora_modules_layers,</span> <span class="pre">3]</span></code>.
The last dimension holds <code class="docutils literal notranslate"><span class="pre">[</span> <span class="pre">module_id,</span> <span class="pre">layer_idx,</span> <span class="pre">adapter_size</span> <span class="pre">D</span> <span class="pre">(i.e.</span> <span class="pre">R</span> <span class="pre">value)</span> <span class="pre">]</span></code>.</p>
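<p>A minimal sketch of assembling both tensors for one request, assuming each adapter entry is given as <code>(module_id, layer_idx, A, B)</code> with <code>A</code> of shape <code>[D, Hi]</code> and <code>B</code> of shape <code>[Ho, D]</code>. Rows can differ in length, so the sketch sizes the last dimension of <code>lora_weights</code> for the longest row and pads the shorter ones, mirroring the <code>&lt;padding&gt;</code> entries in the example tensors. The function name <code>build_lora_tensors</code>, the zero padding value, and the <code>int32</code> config dtype are assumptions for illustration; check the converter script and runtime for the exact dtypes they expect.</p>

```python
import numpy as np


def build_lora_tensors(entries, dtype=np.float16):
    """Build (lora_weights, lora_config) from (module_id, layer_idx, A, B) entries.

    Hypothetical helper. A is the [D, Hi] in-weight, B the [Ho, D] out-weight.
    Each weights row is flatten(A) followed by flatten(B), zero-padded on the
    right to the length of the longest row.
    """
    rows, config = [], []
    for module_id, layer_idx, a, b in entries:
        d = a.shape[0]
        if b.shape[1] != d:
            raise ValueError("A and B must share the adapter size D")
        rows.append(np.concatenate([a.reshape(-1), b.reshape(-1)]).astype(dtype))
        config.append([module_id, layer_idx, d])  # [module_id, layer_idx, adapter_size]

    width = max(r.size for r in rows)
    lora_weights = np.zeros((len(rows), width), dtype=dtype)
    for i, r in enumerate(rows):
        lora_weights[i, : r.size] = r  # remaining entries stay zero (padding)
    lora_config = np.asarray(config, dtype=np.int32)
    return lora_weights, lora_config


# Usage: a `q` adapter (module id 1, D = 2) and a `k` adapter (module id 2, D = 4)
# for layer 0 of a model with hidden size 4.
hidden = 4
entries = [
    (1, 0, np.ones((2, hidden)), np.ones((hidden, 2))),
    (2, 0, np.ones((4, hidden)), np.ones((hidden, 4))),
]
w, c = build_lora_tensors(entries)
print(w.shape)     # (2, 32)
print(c.tolist())  # [[1, 0, 2], [2, 0, 4]]
```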
<p>Reference: this feature supports LoRAs as described in <a class="reference external" href="https://arxiv.org/pdf/2106.09685.pdf">https://arxiv.org/pdf/2106.09685.pdf</a>.</p>
<section id="example-lora-tensors">
<h3>Example LoRA tensors<a class="headerlink" href="#example-lora-tensors" title="Link to this heading"></a></h3>
<p>Here is an example of the loraWeights and loraConfig tensors for a model with tp=1, pp=1, 4 layers, and a hidden size of 4.
The tensors below are for a LoRA that has <code class="docutils literal notranslate"><span class="pre">q</span></code> and <code class="docutils literal notranslate"><span class="pre">k</span></code> adapters.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># loraConfig</span>
<span class="p">[</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span>
<span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span>
<span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="c1"># Note that the final 2 layers only adapt `q`</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">8</span><span class="p">]</span>
<span class="p">]</span>
<span class="c1"># Note: The loraConfig tensor configures the loraWeights tensor.</span>
<span class="c1"># The contents of each row of loraWeights are specified by the corresponding row in loraConfig</span>
<span class="c1"># loraWeights</span>
<span class="c1"># Note: &#39;in weights&#39; and &#39;out weights&#39; are &#39;A&#39; and &#39;B&#39; in the LoRA paper.</span>
<span class="p">[</span>
<span class="p">[</span> <span class="o">&lt;</span><span class="mi">2</span> <span class="n">x</span> <span class="mi">4</span> <span class="ow">in</span> <span class="n">weights</span><span class="o">&gt;</span><span class="p">,</span> <span class="o">&lt;</span><span class="mi">4</span> <span class="n">x</span> <span class="mi">2</span> <span class="n">out</span> <span class="n">weights</span><span class="o">&gt;</span> <span class="o">&lt;</span><span class="n">padding</span><span class="o">&gt;</span> <span class="p">]</span> <span class="c1"># `q` adapter for layer 0</span>
<span class="p">[</span> <span class="o">&lt;</span><span class="mi">4</span> <span class="n">x</span> <span class="mi">4</span> <span class="ow">in</span> <span class="n">weights</span><span class="o">&gt;</span><span class="p">,</span> <span class="o">&lt;</span><span class="mi">4</span> <span class="n">x</span> <span class="mi">4</span> <span class="n">out</span> <span class="n">weights</span><span class="o">&gt;</span> <span class="o">&lt;</span><span class="n">padding</span><span class="o">&gt;</span> <span class="p">]</span> <span class="c1"># `k` adapter for layer 0</span>
<span class="p">[</span> <span class="o">&lt;</span><span class="mi">2</span> <span class="n">x</span> <span class="mi">4</span> <span class="ow">in</span> <span class="n">weights</span><span class="o">&gt;</span><span class="p">,</span> <span class="o">&lt;</span><span class="mi">4</span> <span class="n">x</span> <span class="mi">2</span> <span class="n">out</span> <span class="n">weights</span><span class="o">&gt;</span> <span class="o">&lt;</span><span class="n">padding</span><span class="o">&gt;</span> <span class="p">]</span> <span class="c1"># `q` adapter for layer 1</span>
<span class="p">[</span> <span class="o">&lt;</span><span class="mi">4</span> <span class="n">x</span> <span class="mi">4</span> <span class="ow">in</span> <span class="n">weights</span><span class="o">&gt;</span><span class="p">,</span> <span class="o">&lt;</span><span class="mi">4</span> <span class="n">x</span> <span class="mi">4</span> <span class="n">out</span> <span class="n">weights</span><span class="o">&gt;</span> <span class="o">&lt;</span><span class="n">padding</span><span class="o">&gt;</span> <span class="p">]</span> <span class="c1"># `k` adapter for layer 1</span>
<span class="p">[</span> <span class="o">&lt;</span><span class="mi">2</span> <span class="n">x</span> <span class="mi">4</span> <span class="ow">in</span> <span class="n">weights</span><span class="o">&gt;</span><span class="p">,</span> <span class="o">&lt;</span><span class="mi">4</span> <span class="n">x</span> <span class="mi">2</span> <span class="n">out</span> <span class="n">weights</span><span class="o">&gt;</span> <span class="o">&lt;</span><span class="n">padding</span><span class="o">&gt;</span> <span class="p">]</span> <span class="c1"># `q` adapter for layer 2</span>
<span class="p">[</span> <span class="o">&lt;</span><span class="mi">8</span> <span class="n">x</span> <span class="mi">4</span> <span class="ow">in</span> <span class="n">weights</span><span class="o">&gt;</span><span class="p">,</span> <span class="o">&lt;</span><span class="mi">4</span> <span class="n">x</span> <span class="mi">8</span> <span class="n">out</span> <span class="n">weights</span><span class="o">&gt;</span> <span class="p">]</span> <span class="c1"># `q` adapter for layer 3. Note the final layer has a adapter size of 8</span>
<span class="p">]</span>
</pre></div>
</div>
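<p>The example tensors above can be reproduced with a short NumPy sketch. It fills the adapters with random values (real values come from the converted checkpoint), takes the module ids <code>attn_q = 1</code> and <code>attn_k = 2</code> from the module id table on this page, and zero-pads each row to the longest one; the zero padding value and the variable names are assumptions for illustration.</p>

```python
import numpy as np

HIDDEN = 4   # Hi = Ho = 4 for both q and k in this toy model
Q, K = 1, 2  # module ids for attn_q and attn_k

# (module_id, layer_idx, adapter_size D), in the same order as the loraConfig example
spec = [(Q, 0, 2), (K, 0, 4), (Q, 1, 2), (K, 1, 4), (Q, 2, 2), (Q, 3, 8)]

rng = np.random.default_rng(0)
rows = []
for _, _, d in spec:
    a = rng.standard_normal((d, HIDDEN))  # in weights  (A): D x Hi
    b = rng.standard_normal((HIDDEN, d))  # out weights (B): Ho x D
    rows.append(np.concatenate([a.ravel(), b.ravel()]))

width = max(r.size for r in rows)  # 8*4 + 4*8 = 64 for the rank-8 row
lora_weights = np.zeros((len(rows), width), dtype=np.float16)
for i, r in enumerate(rows):
    lora_weights[i, : r.size] = r  # shorter rows keep trailing zero padding

lora_config = np.array(spec, dtype=np.int32)
print(lora_config.shape, lora_weights.shape)  # (6, 3) (6, 64)
print([r.size for r in rows])                 # [16, 32, 16, 32, 16, 64]
```

<p>Only the layer-3 row, with adapter size 8, fills the full width of 64 values; every other row ends in padding.</p>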
</section>
<section id="lora-module-id-mapping">
<h3>LoRA Module id mapping<a class="headerlink" href="#lora-module-id-mapping" title="Link to this heading"></a></h3>
<p>See <code class="docutils literal notranslate"><span class="pre">LoraModule::ModuleType</span></code> for the module id mapping.</p>
<table class="docutils align-default">
<thead>
<tr class="row-odd"><th class="head"><p>module name (as specified in convert_checkpoint.py scripts)</p></th>
<th class="head"><p>module id</p></th>
<th class="head"><p>description</p></th>
</tr>
</thead>
<tbody>
<tr class="row-even"><td><p>attn_qkv</p></td>
<td><p>0</p></td>
<td><p>combined qkv adapter</p></td>
</tr>
<tr class="row-odd"><td><p>attn_q</p></td>
<td><p>1</p></td>
<td><p>q adapter</p></td>
</tr>
<tr class="row-even"><td><p>attn_k</p></td>
<td><p>2</p></td>
<td><p>k adapter</p></td>
</tr>
<tr class="row-odd"><td><p>attn_v</p></td>
<td><p>3</p></td>
<td><p>v adapter</p></td>
</tr>
<tr class="row-even"><td><p>attn_dense</p></td>
<td><p>4</p></td>
<td><p>adapter for the dense layer in attention</p></td>
</tr>
<tr class="row-odd"><td><p>mlp_h_to_4h</p></td>
<td><p>5</p></td>
<td><p>Llama 2 gated MLP layer (after attention / RMSNorm): up projection adapter</p></td>
</tr>
<tr class="row-even"><td><p>mlp_4h_to_h</p></td>
<td><p>6</p></td>
<td><p>Llama 2 gated MLP layer (after attention / RMSNorm): down projection adapter</p></td>
</tr>
<tr class="row-odd"><td><p>mlp_gate</p></td>
<td><p>7</p></td>
<td><p>Llama 2 gated MLP layer (after attention / RMSNorm): gate adapter</p></td>
</tr>
</tbody>
</table>
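<p>When building <code>loraConfig</code> by hand, the table above can be mirrored as a simple lookup. This is a hypothetical Python mapping that copies the ids from the table; the authoritative definition is <code>LoraModule::ModuleType</code> in the C++ runtime, which may change between releases. The helper <code>config_rows</code> is likewise illustrative and orders rows layer by layer, matching the example tensors.</p>

```python
# Module name -> module id, copied from the table above (hypothetical helper;
# LoraModule::ModuleType in the C++ runtime is the authoritative definition).
LORA_MODULE_IDS = {
    "attn_qkv": 0,
    "attn_q": 1,
    "attn_k": 2,
    "attn_v": 3,
    "attn_dense": 4,
    "mlp_h_to_4h": 5,
    "mlp_4h_to_h": 6,
    "mlp_gate": 7,
}


def config_rows(target_modules, num_layers, adapter_size):
    """Return [module_id, layer_idx, adapter_size] rows, layer by layer."""
    unknown = [m for m in target_modules if m not in LORA_MODULE_IDS]
    if unknown:
        raise KeyError(f"unknown LoRA module(s): {unknown}")
    return [
        [LORA_MODULE_IDS[m], layer, adapter_size]
        for layer in range(num_layers)
        for m in target_modules
    ]


# Same target modules as the build command on this page, for a 2-layer toy model.
print(config_rows(["attn_q", "attn_k", "attn_v"], num_layers=2, adapter_size=8))
# [[1, 0, 8], [2, 0, 8], [3, 0, 8], [1, 1, 8], [2, 1, 8], [3, 1, 8]]
```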
</section>
</section>
</section>
</div>
</div>
<footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
<a href="new_workflow.html" class="btn btn-neutral float-left" title="New Workflow" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
<a href="perf_best_practices.html" class="btn btn-neutral float-right" title="Best Practices for Tuning the Performance of TensorRT-LLM" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
</div>
<hr/>
<div role="contentinfo">
<p>&#169; Copyright 2023, NVidia.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>