<section id="troubleshooting">
|
||
<span id="id1"></span><h1>Troubleshooting<a class="headerlink" href="#troubleshooting" title="Link to this heading"></a></h1>
|
||
<p>This document describes how to debug in TensorRT-LLM.</p>
|
||
<p>Usually, we want to print the intermediate tensor values when debugging a TensorRT-LLM model.
|
||
TensorRT-LLM obeys define-and-run paradigm, we should mark the interested intermediate tensors as the network outputs.
|
||
Then, we print the values at runtime.</p>
## Build Errors

Many build errors can be resolved by simply deleting the build tree. Try running the build script with `--clean`, or run `rm -r cpp/build`; a sketch of both options follows.
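For instance, a clean rebuild might look like the sketch below. It assumes the repository's usual layout, in which `scripts/build_wheel.py` is the build script:

```bash
# Option 1: ask the build script to clean its own build tree before rebuilding.
python3 scripts/build_wheel.py --clean

# Option 2: delete the C++ build tree manually, then rebuild.
rm -rf cpp/build
python3 scripts/build_wheel.py
```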
## cuDNN Linking Errors

If you encounter errors such as "Entry Point Not Found" (see, for example, [#1062](https://github.com/NVIDIA/TensorRT-LLM/issues/1062)), the issue might be a mismatch between the `cuDNN` libraries shipped with `torch` and `tensorrt`. To rectify this, try the following steps:

```bash
python -m pip uninstall -y tensorrt_llm
python -m pip install --upgrade pip
python -m pip install nvidia-cudnn-cu11==8.9.4.25 --no-cache-dir
python -m pip install --pre --extra-index-url https://pypi.nvidia.com/ tensorrt==9.2.0.post12.dev5 --no-cache-dir
python -m pip uninstall -y nvidia-cudnn-cu11
python -m pip install tensorrt_llm --extra-index-url https://pypi.nvidia.com/ --extra-index-url https://download.pytorch.org/whl/cu121
```
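After reinstalling, a quick sanity check (a minimal sketch, assuming both packages installed correctly) is to import `torch` and `tensorrt` in the same interpreter:

```bash
# If the cuDNN mismatch is resolved, both imports succeed in one process
# without an "Entry Point Not Found" error at load time.
python -c "import torch, tensorrt; print(torch.__version__, torch.backends.cudnn.version(), tensorrt.__version__)"
```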
## Debug on Unit Tests

1. Register the intermediate tensors as network outputs with the `register_network_output` API:

```python
class MLP(Module):

    def __init__(self,
                 hidden_size,
                 ffn_hidden_size,
                 bias=True,
                 tp_group=None,
                 tp_size=1):
        super().__init__()
        self.fc = tensorrt_llm.layers.ColumnLinear(hidden_size,
                                                   ffn_hidden_size,
                                                   bias=bias,
                                                   tp_group=tp_group,
                                                   tp_size=tp_size,
                                                   gather_output=False)
        self.proj = tensorrt_llm.layers.RowLinear(ffn_hidden_size,
                                                  hidden_size,
                                                  bias=bias,
                                                  tp_group=tp_group,
                                                  tp_size=tp_size)

    def forward(self, hidden_states):
        inter = self.fc(hidden_states)
        inter = tensorrt_llm.functional.relu(inter)
        # Here, we want to print the tensor value after relu
        self.register_network_output('inter', inter)
        output = self.proj(inter)
        return output
```

2. Mark the intermediate tensors as network outputs:

```python
for k, v in gm.named_network_outputs():
    net._mark_output(v, k, dtype)
```

3. Print the tensors at runtime:

```python
print(outputs.keys())
print(outputs['inter'])
```

Here is the [full example](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/tests/test_debugging_api.py).
## Debug on E2E Models

Here is an example of printing the values of the MLP output tensor in the GPT model.

1. In `tensorrt_llm/models/gpt/model.py`, register the MLP output tensor:

```python
hidden_states = residual + attention_output.data

residual = hidden_states
hidden_states = self.post_layernorm(hidden_states)

hidden_states = self.mlp(hidden_states)
# register as model output
# ------------------------------------------------------
self.register_network_output('mlp_output', hidden_states)
# ------------------------------------------------------

hidden_states = residual + hidden_states
```

2. Build the TensorRT engine of the model. When building engines with `trtllm-build`, enable the `--enable_debug_output` option:

```bash
cd examples/gpt

# Download hf gpt2 model
rm -rf gpt2 && git clone https://huggingface.co/gpt2-medium gpt2
pushd gpt2 && rm pytorch_model.bin model.safetensors && wget -q https://huggingface.co/gpt2-medium/resolve/main/pytorch_model.bin && popd

# Convert to TensorRT-LLM checkpoint
python3 convert_checkpoint.py --model_dir gpt2 \
        --dtype float16 \
        --output_dir gpt2/trt_ckpt/fp16/1-gpu

# Build TensorRT-LLM engines with --enable_debug_output
trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/1-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --enable_debug_output \
        --output_dir gpt2/trt_engines/fp16/1-gpu
```

3. Print the intermediate output tensors. In `tensorrt_llm/runtime/generation.py`, print the debug info:

```python
stream = torch.cuda.current_stream().cuda_stream
instance_idx = step % 2
if self.cuda_graph_mode and self.runtime.cuda_graph_instances[
        instance_idx] is not None:
    # launch cuda graph
    CUASSERT(
        cudart.cudaGraphLaunch(
            self.runtime.cuda_graph_instances[instance_idx], stream))
    ok = True
else:
    ok = self.runtime._run(context, stream)

if not ok:
    raise RuntimeError(f"Executing TRT engine failed step={step}!")
if self.debug_mode:
    torch.cuda.synchronize()
    # -------------------------------------------
    if step == 0:
        print(self.debug_buffer.keys())
    print(f"Step: {step}")
    print(self.debug_buffer['transformer.layers.6.mlp_output'])
    # -------------------------------------------
```

Then, run `../run.py` with `--debug_mode` and `--use_py_session`:

```bash
python3 ../run.py --engine_dir gpt2/trt_engines/fp16/1-gpu \
        --tokenizer_dir gpt2 \
        --max_output_len 8 \
        --debug_mode \
        --use_py_session
```

We will see the tensor values:

```
......
dict_keys(['context_lengths', 'cache_indirection', 'position_ids', 'logits', 'last_token_ids', 'input_ids', 'kv_cache_block_pointers', 'host_kv_cache_block_pointers', 'sequence_length', 'host_past_key_value_lengths', 'host_sink_token_length', 'host_request_types', 'host_max_attention_window_sizes', 'host_context_lengths', 'transformer.layers.0.mlp_output', 'transformer.layers.1.mlp_output', 'transformer.layers.2.mlp_output', 'transformer.layers.3.mlp_output', 'transformer.layers.4.mlp_output', 'transformer.layers.5.mlp_output', 'transformer.layers.6.mlp_output', 'transformer.layers.7.mlp_output', 'transformer.layers.8.mlp_output', 'transformer.layers.9.mlp_output', 'transformer.layers.10.mlp_output', 'transformer.layers.11.mlp_output', 'transformer.layers.12.mlp_output', 'transformer.layers.13.mlp_output', 'transformer.layers.14.mlp_output', 'transformer.layers.15.mlp_output', 'transformer.layers.16.mlp_output', 'transformer.layers.17.mlp_output', 'transformer.layers.18.mlp_output', 'transformer.layers.19.mlp_output', 'transformer.layers.20.mlp_output', 'transformer.layers.21.mlp_output', 'transformer.layers.22.mlp_output', 'transformer.layers.23.mlp_output'])
Step: 0
tensor([[ 0.0294, -0.0260, -0.0776,  ..., -0.0560, -0.0235,  0.0273],
        [-0.0071,  0.5879,  0.1993,  ..., -1.0449, -0.6299,  0.5957],
        [-0.8779,  0.1050,  0.7090,  ...,  0.0910,  1.0713, -0.2939],
        ...,
        [ 0.1212, -0.0903, -0.5918,  ..., -0.1045, -0.3445,  0.1082],
        [-1.0723, -0.0732,  0.6157,  ...,  0.3452,  0.2998,  0.2649],
        [-0.7134,  0.9692, -0.1141,  ..., -0.0096,  0.9521,  0.1437]],
       device='cuda:0', dtype=torch.float16)
Step: 1
tensor([[-0.2107,  0.5874,  0.8179,  ...,  0.7900, -0.6890,  0.6064]],
       device='cuda:0', dtype=torch.float16)
Step: 2
tensor([[ 0.4192, -0.0047,  1.3887,  ..., -0.9028, -0.0682, -0.2820]],
       device='cuda:0', dtype=torch.float16)
Step: 3
tensor([[-0.7949, -0.5073, -0.1721,  ..., -0.5830, -0.1378, -0.0070]],
       device='cuda:0', dtype=torch.float16)
Step: 4
tensor([[-0.0804,  0.1272, -0.6255,  ..., -0.1072, -0.0523,  0.7144]],
       device='cuda:0', dtype=torch.float16)
Step: 5
tensor([[-0.3328, -0.8828,  0.3442,  ...,  0.8149, -0.0630,  1.2305]],
       device='cuda:0', dtype=torch.float16)
Step: 6
tensor([[-0.2225, -0.2079, -0.1459,  ..., -0.3555, -0.1672,  0.1135]],
       device='cuda:0', dtype=torch.float16)
Step: 7
tensor([[ 0.1290, -0.1556,  0.3977,  ..., -0.8218, -0.3291, -0.8672]],
       device='cuda:0', dtype=torch.float16)
Input [Text 0]: "Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0]: " chef before moving to London in the early"
```
## Debug Execution Errors

- If you use plugins, you can set the environment variable `CUDA_LAUNCH_BLOCKING=1` so that kernels are launched synchronously, with their return status checked immediately; an example invocation follows this list.
- If you see memory errors, make sure that the engine inputs respect the build-time shapes and that they reside **on the correct device** (CPU/GPU).
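For example, here is a sketch that reuses the `run.py` invocation from the previous section with synchronous kernel launches:

```bash
# With CUDA_LAUNCH_BLOCKING=1, a failing kernel is reported at its call site
# rather than at a later, unrelated synchronization point.
CUDA_LAUNCH_BLOCKING=1 python3 ../run.py --engine_dir gpt2/trt_engines/fp16/1-gpu \
        --tokenizer_dir gpt2 \
        --max_output_len 8 \
        --debug_mode \
        --use_py_session
```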
## Installation Errors

As with [build errors](#build-errors), many installation problems can be resolved by deleting the build tree: run the build script with `--clean`, or run `rm -r cpp/build`.
## Tips

- It is recommended to add the options `--shm-size=1g --ulimit memlock=-1` to the `docker` or `nvidia-docker` run command. Otherwise you may see NCCL errors when running multi-GPU inference. See the [NCCL troubleshooting documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#errors) for details. A sample invocation is sketched below.
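For example (a sketch: the image name and mount are placeholders for your own setup):

```bash
docker run --rm -it --gpus all \
    --shm-size=1g --ulimit memlock=-1 \
    -v "$PWD":/workspace \
    <tensorrt-llm-image>
```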
- When building models, memory-related issues such as the following may happen:

```
[09/23/2023-03:13:00] [TRT] [E] 9: GPTLMHeadModel/layers/0/attention/qkv/PLUGIN_V2_Gemm_0: could not find any supported formats consistent with input/output data types
[09/23/2023-03:13:00] [TRT] [E] 9: [pluginV2Builder.cpp::reportPluginError::24] Error Code 9: Internal Error (GPTLMHeadModel/layers/0/attention/qkv/PLUGIN_V2_Gemm_0: could not find any supported formats consistent with input/output data types)
```

One possible solution is to reduce the amount of memory needed by reducing the maximum batch size, input length, and output length. Another option is to enable plugins, for example: `--gpt_attention_plugin`. An illustrative build command follows.
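For example, here is an illustrative sketch of a lower-memory build, reusing the GPT-2 checkpoint from the E2E section above; the size flags are assumptions and their exact names can vary across TensorRT-LLM versions:

```bash
trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/1-gpu \
        --gpt_attention_plugin float16 \
        --max_batch_size 4 \
        --max_input_len 512 \
        --max_output_len 128 \
        --output_dir gpt2/trt_engines/fp16/1-gpu
```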
- MPI + Slurm

TensorRT-LLM is an [MPI](https://en.wikipedia.org/wiki/Message_Passing_Interface)-aware package that uses [`mpi4py`](https://mpi4py.readthedocs.io/en/stable/). If you are running scripts in a [Slurm](https://slurm.schedmd.com/) environment, you might encounter interference:

```
--------------------------------------------------------------------------
PMI2_Init failed to initialize.  Return code: 14
--------------------------------------------------------------------------
--------------------------------------------------------------------------
The application appears to have been direct launched using "srun",
but OMPI was not built with SLURM's PMI support and therefore cannot
execute. There are several options for building PMI support under
SLURM, depending upon the SLURM version you are using:

  version 16.05 or later: you can use SLURM's PMIx support. This
  requires that you configure and build SLURM --with-pmix.

  Versions earlier than 16.05: you must use either SLURM's PMI-1 or
  PMI-2 support. SLURM builds PMI-1 by default, or you can manually
  install PMI-2. You must then build Open MPI using --with-pmi pointing
  to the SLURM PMI library location.

Please configure as appropriate and try again.
--------------------------------------------------------------------------
```

As a rule of thumb, if you are running TensorRT-LLM interactively on a Slurm node, prefix your commands with `mpirun -n 1` to run TensorRT-LLM in a dedicated MPI environment, not the one provided by your Slurm allocation.

For example: `mpirun -n 1 python3 examples/gpt/build.py ...`