# C++ GPT Runtime

TensorRT-LLM includes a C++ component to execute TensorRT engines built with
the Python API, as described in the [Architecture](architecture.html) document.
That component is called the C++ runtime.

The API of the C++ runtime is composed of the classes declared in
[`cpp/include/tensorrt_llm/runtime`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/include/tensorrt_llm/runtime)
and implemented in
[`cpp/tensorrt_llm/runtime`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/tensorrt_llm/runtime).
An example of how to use the C++ runtime for a GPT-like auto-regressive model
can be found in
[`cpp/tests/runtime/gptSessionTest.cpp`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/tests/runtime/gptSessionTest.cpp).

Although the components described in this document have GPT in their names,
they are not restricted to that specific model. The same classes can be used to
implement other auto-regressive models such as BLOOM, GPT-J, GPT-NeoX or LLaMA.

Complete support for encoder-decoder models, like T5, will be added to
TensorRT-LLM in a future release. An experimental version, currently Python
only, can be found in the
[`examples/enc_dec`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/examples/enc_dec) folder.
## The Session

The main component of the C++ runtime is the session. For GPT-like
auto-regressive models, it is the
[`GptSession`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/include/tensorrt_llm/runtime/gptSession.h) class.
### Creation

The constructors of that class allow users to specify the model and the
environment in which to execute it. The model is described by an instance of the
[`GptModelConfig`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/include/tensorrt_llm/runtime/gptModelConfig.h)
class and a pointer to the TensorRT engine that must be executed to perform the
inference. The environment is configured through the
[`WorldConfig`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/include/tensorrt_llm/runtime/worldConfig.h)
class (that name comes from
[MPI](https://en.wikipedia.org/wiki/Message_Passing_Interface) and its "famous"
`MPI_COMM_WORLD` default communicator). The constructor also accepts an
optional object to log information, warnings and errors:

```cpp
#include <tensorrt_llm/runtime/gptSession.h>

using namespace tensorrt_llm::runtime;

GptSession session(sessionConfig, // Configuration of the session,
                   modelConfig,   // Description of the model,
                   worldConfig,   // Description of the environment,
                   engineBuffer,  // The compiled TensorRT engine (const void*),
                   engineSize,    // The size in bytes of the TensorRT engine (size_t),
                   logger);       // The optional logger.
```

The above constructor accepts a `const void*` pointer to the engine and the
associated size (in bytes) of that buffer. Other overloads take
`std::vector<uint8_t>` or `std::string` arguments that encapsulate the engine.
A sketch of loading an engine from disk and using one of those overloads
follows.
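As an illustration, the following sketch reads a serialized engine from disk
and uses the `std::vector<uint8_t>` overload. It is a minimal sketch, not the
library's reference usage: the file name is only a placeholder, and
`sessionConfig`, `modelConfig` and `worldConfig` are assumed to have been
created as described in the next sections.

```cpp
#include <tensorrt_llm/runtime/gptSession.h>

#include <fstream>
#include <iterator>
#include <vector>

using namespace tensorrt_llm::runtime;

// Read the serialized TensorRT engine produced by the build step into host memory.
// The file name below is only an example.
std::ifstream engineFile("gpt_float16_tp1_rank0.engine", std::ios::binary);
std::vector<uint8_t> engineBuffer((std::istreambuf_iterator<char>(engineFile)),
                                  std::istreambuf_iterator<char>());

// Construct the session with the overload that takes the engine as a vector.
// The logger argument is optional and omitted here.
GptSession session(sessionConfig, modelConfig, worldConfig, engineBuffer);
```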
#### Session Configuration

The session configuration is an instance of the
[`GptSession::Config`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/include/tensorrt_llm/runtime/gptSession.h) class.
The constructor of this class requires three arguments:

 * `maxBatchSize`, the maximum number of sequences in a batch,
 * `maxBeamWidth`, the maximum width of the beams in beam-search,
 * `maxSequenceLength`, the maximum total length of a sequence (input and
   generated tokens).

Additionally, the class encapsulates the following optional parameters
(they are declared as public member variables and can be accessed directly):

 * `decoderPerRequest`, whether the session will use a different decoder per
   request. It must be set to `true` when running in-flight batching,
 * `cudaGraphMode`, whether the session will use CUDA graphs to execute the
   engine in the generation phase,
 * `kvCacheConfig`, which encapsulates the parameters of the paged KV cache,
   when the paged KV cache is enabled in the engine:
   * `maxTokens`, the maximum number of tokens that will have to be stored in
     the paged KV cache,
   * `freeGpuMemoryFraction`, the fraction of free GPU memory that will be
     reserved for the paged KV cache,
 * `ctxMicroBatchSize`, the micro-batch size to be used in the context phase.
   Batches passed to `GptSession::generate` will be split into smaller
   micro-batches of this size,
 * `genMicroBatchSize`, the micro-batch size to be used in the generation
   phase. Batches passed to `GptSession::generate` will be split into smaller
   micro-batches of this size.

A sketch of assembling such a configuration is shown below.
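This is a minimal sketch of the configuration described above. The values are
arbitrary, and the optional members are assigned directly as public members;
their exact types may differ slightly between releases, so check
`gptSession.h` for the current definitions.

```cpp
#include <tensorrt_llm/runtime/gptSession.h>

using namespace tensorrt_llm::runtime;

// Mandatory constructor arguments: maxBatchSize, maxBeamWidth, maxSequenceLength.
GptSession::Config sessionConfig(8, 1, 1024);

// Optional parameters are public members and can be set directly.
sessionConfig.decoderPerRequest = true;                   // required for in-flight batching
sessionConfig.cudaGraphMode = false;                      // CUDA graphs in the generation phase
sessionConfig.kvCacheConfig.maxTokens = 4096;             // paged KV cache capacity (in tokens)
sessionConfig.kvCacheConfig.freeGpuMemoryFraction = 0.9f; // fraction of free GPU memory to reserve
sessionConfig.ctxMicroBatchSize = 4;                      // micro-batch size for the context phase
sessionConfig.genMicroBatchSize = 4;                      // micro-batch size for the generation phase
```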
#### Model Configuration

The model configuration is an instance of the
[`GptModelConfig`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/include/tensorrt_llm/runtime/gptModelConfig.h) class.
That class encapsulates the following parameters (they are declared as private
member variables and exposed through getters and setters):

 * `vocabSize`, the size of the vocabulary,
 * `numLayers`, the number of layers in the model,
 * `numHeads`, the number of heads in the attention block,
 * `numKvHeads`, the number of heads for K and V in the attention component.
   When the number of K/V heads is the same as the number of (Q) heads, the
   model uses Multi-head Attention. When the number of K/V heads is 1, it uses
   Multi-query Attention. Otherwise, it uses Group-query Attention. See
   [GPT Attention](gpt_attention.html); a small classification sketch follows
   this list,
 * `hiddenSize`, the size of the hidden dimension,
 * `dataType`, the datatype that was used to build the TensorRT engine and that
   must be used to run the model during inference,
 * `useGptAttentionPlugin`, indicates if the [GPT Attention](gpt_attention.html)
   operator was compiled using the
   [GPT Attention plugin](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/tensorrt_llm/plugins/gptAttentionPlugin),
 * `inputPacked`, indicates that the input must be packed (or padded when set
   to `false`). For performance reasons, it is recommended to always use the
   packed format, even if the default is `false` (this will change in a future
   release). See [GPT Attention](gpt_attention.html),
 * `pagedKvCache`, indicates if the K/V cache uses paging.
   See [GPT Attention](gpt_attention.html),
 * `tokensPerBlock`, the number of tokens in each block of the K/V cache. It is
   relevant when the paged K/V cache is enabled. The default value is 64. See
   [GPT Attention](gpt_attention.html),
 * `quantMode`, controls the quantization method. See
   [Numerical Precision](precision.html),
 * `maxBatchSize`, the maximum batch size that the TensorRT engine was built for,
 * `maxInputLen`, the maximum size of the input sequences,
 * `maxSequenceLen`, the maximum total size (input + output) of the sequences.
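To make the `numKvHeads` rule concrete, this small sketch classifies the
attention variant from the two head counts. The numeric values are purely
illustrative; in practice they are read from the `GptModelConfig` getters.

```cpp
// Illustrative values; in practice they come from GptModelConfig.
int const numHeads   = 32; // number of query (Q) heads
int const numKvHeads = 8;  // number of K/V heads

if (numKvHeads == numHeads)
{
    // Multi-head Attention (MHA): one K/V head per Q head.
}
else if (numKvHeads == 1)
{
    // Multi-query Attention (MQA): a single K/V head shared by all Q heads.
}
else
{
    // Group-query Attention (GQA): each K/V head is shared by a group of Q heads.
}
```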
#### World Configuration

Familiarity with
[MPI](https://en.wikipedia.org/wiki/Message_Passing_Interface) is not required
to use the TensorRT-LLM C++ runtime. There are two main things you need to
know: (1) the C++ runtime in TensorRT-LLM uses
[processes](https://en.wikipedia.org/wiki/Process_(computing)) to execute
TensorRT engines on the different GPUs. Those GPUs can be located on a single
node as well as on different nodes in a cluster. Each process is called a
*rank* in MPI. (2) The ranks are grouped in communication groups. The
TensorRT-LLM C++ runtime calls that group the *world*.

The world configuration is an instance of the
[`WorldConfig`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/include/tensorrt_llm/runtime/worldConfig.h)
class. In this release, that class encapsulates the following parameters:

 * `tensorParallelism`, the number of ranks that collaborate to implement
   Tensor Parallelism (TP). With TP, each GPU performs computations for all the
   layers of the model, and some of those computations are distributed across
   the GPUs. TP is more balanced than PP (see below) in most cases, but
   requires higher bandwidth between the GPUs. It is the recommended setting in
   the presence of NVLINK between GPUs,
 * `pipelineParallelism`, the number of ranks that collaborate to implement
   Pipeline Parallelism (PP). With PP, each GPU works on a subset of
   consecutive layers, and communications between the GPUs happen only at the
   boundaries of the subsets of layers. It is harder to guarantee the full
   utilization of the GPUs with PP, but it requires less memory bandwidth. It
   is recommended in the absence of NVLINK between GPUs,
 * `rank`, the unique identifier of the rank (see below),
 * `gpusPerNode`, the number of GPUs on each node. Having that information
   allows the C++ runtime to optimize communications between GPUs in a node
   (like taking advantage of the
   [NVLINK](https://www.nvidia.com/en-us/data-center/nvlink/) interconnect
   between GPUs of an A100
   [DGX](https://www.nvidia.com/en-us/data-center/dgx-platform/) node).

For a multi-GPU configuration (single or multi-node), each rank must create its
own instance of `GptSession` using its own `WorldConfig`. A typical example is:

```cpp
#include "tensorrt_llm/common/mpiUtils.h"

// Get the unique identifier for each rank.
auto const rank = COMM_SESSION.getRank();

// Create the TensorRT-LLM Runtime WorldConfig.
tensorrt_llm::runtime::WorldConfig worldConfig(tensorParallelism, pipelineParallelism, rank);

// Create the GPT session (as shown above).
tensorrt_llm::runtime::GptSession session(sessionConfig, modelConfig, worldConfig, ...);
```

TensorRT-LLM also provides users with the following simplified API:

```cpp
auto worldConfig = tensorrt_llm::runtime::WorldConfig::mpi();
```

Once compiled, that C++ code must be executed using the `mpirun` command
installed on the system (talk to your system administrator if needed):

```bash
# Launch the program using two processes (worldSize == 2 and ranks == {0, 1}).
mpirun -n 2 ...
```
### Generation

The `GptSession::generate` member function performs the generation loop. Given
input tensors to read from and output tensors to populate, that member function
runs the generation loop until it reaches the maximum number of tokens that can
be produced or every sequence has reached completion (due to the production of
the "end-of-sequence" token or a word in the list of "stop words"). The
pseudo-code of that function looks like the following (member function names
were changed to keep the presentation simple):

```cpp
// Have all the sequences in the batch reached completion?
bool allFinished = false;

// Until all sequences are finished or the number of steps reaches the limit...
for (int step = 0; !allFinished && step < maxNewTokens; ++step) {

    // Trigger the computation of the logits...
    computeLogits(...);

    // Run the sampling to produce a token (for each active sequence) from the logits.
    allFinished = generateTokensFromLogits(...);

    // Callback to stream the output tokens while the generation loop continues.
    onTokenGenerated(...);
}
```
#### Inputs and Outputs

The `generate` member function takes an instance of the
[`GenerationInput`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/include/tensorrt_llm/runtime/generationInput.h)
class and populates an instance of the
[`GenerationOutput`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/include/tensorrt_llm/runtime/generationOutput.h) class.

***Mandatory inputs***

 * `endId`, the token ID that marks the end of the input sequence (also known
   as `EOS` or end-of-sequence). For example, it is `50,256` for the GPT-2
   model, which has a vocabulary of `50,257` tokens,
 * `padId`, the token ID used for padding (that is, it fills the slots at an
   index greater than or equal to the input length for padded sequences). It
   can be set to the same value as `endId`,
 * `ids`, the tensor of input IDs. That tensor must be allocated on the GPU.
   When the input tensor is padded, the shape of `ids` is
   `[batchSize, maxInputLength]`, where `batchSize` and `maxInputLength` must
   respect the maximum sizes in the `sessionConfig` passed to the `GptSession`
   constructor. When the input is packed, the shape of `ids` is `[numTokens]`,
   where `numTokens` is the sum of the lengths of the different sequences in
   the batch,
 * `lengths`, the tensor of input sequence lengths. That tensor must be
   allocated on the GPU and contain `batchSize` values,
 * `packed`, indicates if the `ids` tensor is packed or padded. In this
   release, that flag must match the value passed to the constructor through
   the instance of the `ModelConfig` class. In a future release, the session
   may be made more flexible and automatically pad or pack the input.

***Optional inputs***

 * `embeddingBiasOpt`, a tensor of floating-point values on the GPU that
   contains the bias to add to the logits during sampling (after the projection
   from hidden states to logits, as the last step of the model). This tensor
   must have `vocabSize` elements (as defined in the `ModelConfig` argument
   passed to the constructor),
 * `badWordsList`, a tensor of integers on the GPU that encodes the list of
   words that have to be banned from the generated sequences. Its shape is
   `[2, badWordsLength]`, as explained below, or
   `[batchSize, 2, badWordsLength]` when there is a different list for each
   sequence in the batch,
 * `stopWordsList`, a tensor of integers on the GPU that encodes the list of
   words that trigger the end of the generation for a sequence. Its shape is
   `[2, stopWordsLength]`, as explained below, or
   `[batchSize, 2, stopWordsLength]` when there is a different list for each
   sequence in the batch,
 * `maxNewTokens`, the maximum number of tokens to generate.

A sketch of populating a `GenerationInput` from those fields is shown below.
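The following sketch assembles a `GenerationInput` from the fields listed
above. It is only a sketch: it assumes `inputIds` and `inputLengths` are tensor
handles already allocated on the GPU (for example through the session's buffer
manager), and the constructor argument order follows the field descriptions
above and should be checked against `generationInput.h`.

```cpp
#include <tensorrt_llm/runtime/generationInput.h>

using namespace tensorrt_llm::runtime;

// endId and padId for GPT-2; padId may reuse endId as noted above.
GenerationInput generationInput{
    50256,         // endId
    50256,         // padId
    inputIds,      // GPU tensor of input token IDs, [batchSize, maxInputLength] or [numTokens]
    inputLengths,  // GPU tensor of input lengths, [batchSize]
    false};        // packed: must match the ModelConfig passed to the session

// Optional inputs are public members and can be set when needed.
generationInput.maxNewTokens = 64;
```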
The `badWordsList` and `stopWordsList` tensors have the same shape:
`[2, length]`. Let's consider an example with three words to describe the
representation of those lists. The first word contains tokens `[5, 7, 3]`, the
second one contains `[9, 2]` and the third one is composed of tokens
`[6, 2, 4, 1]`. In total, there are 9 tokens; that's the length. The shape of
the tensor is `[2, 9]`. The first row of the tensor must contain the 9 token
IDs and the second row must store the
[inclusive prefix-sum](https://en.wikipedia.org/wiki/Prefix_sum)
of the word lengths, as shown in the following diagram:

```
   0           3       5              9
   |           |       |              |
   V           V       V              V
[  5,  7,  3,  9,  2,  6,  2,  4,  1]
[  3,  5,  9, -1, -1, -1, -1, -1, -1]
```

In case all the words are made of a single token, the inner-most dimension of
the tensor must be increased by 1 (i.e. the length for 4 words, each made of a
single token, must be 5 instead of 4 – the shape is `[2, 5]`).
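As an illustration, the host-side buffer for the three-word example above can
be laid out as follows before being copied to a GPU tensor of shape `[2, 9]`;
the variable name is only a placeholder.

```cpp
#include <cstdint>
#include <vector>

// Row 0: concatenated token IDs of the three words [5,7,3], [9,2] and [6,2,4,1].
// Row 1: inclusive prefix-sum of the word lengths (3, 5, 9), padded with -1.
std::vector<int32_t> stopWordsHost = {
     5,  7,  3,  9,  2,  6,  2,  4,  1,
     3,  5,  9, -1, -1, -1, -1, -1, -1};

// This host buffer is then copied to a GPU tensor of shape [2, 9] and assigned
// to GenerationInput::stopWordsList (badWordsList uses the same layout).
```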
***Mandatory outputs***

 * `ids`, a tensor that contains the output token IDs. Its shape is
   `[batchSize, beamWidth, maxSeqLength]`, where `maxSeqLength` is the sum of
   `maxInputLength` and `maxNewTokens`. After generation, it contains, for each
   sequence, a copy of the input tokens followed by the output tokens. When a
   sequence is shorter than `maxSeqLength`, padding tokens are added at the end
   of the sequence.

*Note that the shape of that tensor is different in this version of
TensorRT-LLM from its shape in previous versions, where it was
`[maxSeqLength, batchSize, beamWidth]`.*

***Optional outputs***

 * `logProbs`, a tensor of floating-point values on the GPU where the
   log-probabilities of the generated tokens are stored. Its shape is
   `[maxNewTokens, batchSize, beamWidth]`. Its shape will likely change in a
   future release to match the shape of the output `ids` tensor,
 * `contextLogits`, a tensor of values on the GPU (same datatype as the
   computation type) where the logits for the context are stored. Its shape is
   `[batchSize, maxSequenceLength, vocabSizePadded]`. If `remove_input_padding`
   is enabled, its shape is `[packedSize, vocabSizePadded]`. This buffer is
   only filled in if the TensorRT engine was built with the
   `gather_context_logits` or `gather_all_token_logits` parameter enabled.

   After inference is complete, you can read the context logits from
   `GenerationOutput.contextLogits`; they reside on the GPU. For a concrete way
   to retrieve them, refer to the example in
   [gptSessionBenchmark.cpp](https://github.com/NVIDIA/TensorRT-LLM/blob/main/benchmarks/cpp/gptSessionBenchmark.cpp).

   It is important to point out that enabling that computation may have an
   impact on performance (the final LM head has to perform a matrix
   multiplication on all the context tokens instead of just the last one),
 * `generationLogits`, a tensor of values on the GPU (same datatype as the
   computation type) where the logits for the generation are stored. Its shape
   is `[batchSize, beamWidth, maxOutputLen, vocabSizePadded]`. This buffer is
   only filled in if the TensorRT engine was built with the
   `gather_generation_logits` or `gather_all_token_logits` parameter enabled.

   Generation logits can also be obtained through
   `GenerationOutput.generationLogits` after inference is complete,
 * `onTokenGenerated`, a callback function invoked in the generation loop to
   pass newly generated tokens to the caller while the loop continues to
   execute. An implementation of that callback must accept the output `ids`
   tensor, the generation `step` and a boolean flag that indicates if the
   generation is complete.

A sketch of wiring these outputs together and launching the generation is
shown below.
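Putting the pieces together, a minimal sketch of launching the generation looks
like the following. It assumes `outputIds` and `outputLengths` are GPU tensors
sized as described above, reuses the `session` and `generationInput` objects
from the earlier sketches, and the exact callback signature should be checked
against `generationOutput.h`.

```cpp
#include <tensorrt_llm/runtime/generationOutput.h>

using namespace tensorrt_llm::runtime;

// ids is the mandatory output described above; the lengths tensor is assumed
// here as its companion buffer.
GenerationOutput generationOutput{outputIds, outputLengths};

// Optional: stream tokens back while the generation loop is still running.
generationOutput.onTokenGenerated
    = [](auto const& ids, auto step, bool finished) {
          // Consume the tokens produced at this step (e.g., detokenize and print).
      };

// Run the generation loop (samplingConfig is described in the next section).
session.generate(generationOutput, generationInput, samplingConfig);
```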
#### Sampling Parameters

The [`SamplingConfig`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/include/tensorrt_llm/runtime/samplingConfig.h)
class encapsulates the parameters that control the
[generation](https://huggingface.co/blog/how-to-generate) of new tokens.
Except for the `beamWidth` parameter, all the fields are optional and the
runtime will use a default value if no value is provided by the user. For
vector fields, the TensorRT-LLM runtime supports one value per sequence (that
is, the vector contains `batchSize` values). If all the sequences use the same
value for a given parameter, the vector can be limited to a single element
(that is, `size() == 1`).

***General***

 * `temperature`, a vector of floating-point numbers to control the modulation
   of logits when sampling new tokens. The default value is `1.0f`,
 * `minLength`, a vector of integers to set a lower bound on the number of
   tokens generated. The default value is `0`,
 * `repetitionPenalty`, a vector of floating-point numbers to penalize tokens
   based on how often they appear in the sequence. The default value is `0.f`,
 * `presencePenalty`, a vector of floating-point numbers to penalize tokens
   already present in the sequence (irrespective of the number of appearances).
   The default value is `0.f`,
 * `frequencyPenalty`, a vector of floating-point numbers to penalize tokens
   already present in the sequence (dependent on the number of appearances).
   The default value is `0.f`.

The parameters `repetitionPenalty`, `presencePenalty`, and `frequencyPenalty`
are not mutually exclusive.
***Sampling***

 * `randomSeed`, a vector of 64-bit integers to control the random seed used by
   the random number generator in sampling. Its default value is `0`,
 * `topK`, a vector of integers to control the number of logits to sample from.
   Its default value is `0`. Note that if different values are provided for the
   different sequences in the batch, the performance of the implementation will
   depend on the largest value. For efficiency reasons, we recommend batching
   requests with similar `topK` values together,
 * `topP`, a vector of floating-point values to control the top-P probability
   to sample from. Its default value is `0.f`,
 * `topPDecay`, `topPMin` and `topPResetIds`, vectors to control the decay in
   the top-P algorithm. The top-P values are modulated by a decay that
   exponentially depends on the length of the sequence, as explained in
   [*Factuality Enhanced Language Models for Open-Ended Text Generation*](https://arxiv.org/abs/2206.04624).
   `topPDecay` is the decay, `topPMin` is the lower bound and `topPResetIds`
   indicates where to reset the decay. The defaults are `1.f`, `1.0e-6f` and
   `-1`.

If both the `topK` and `topP` fields are set, the top-K method will be run for
sequences with a `topK` value greater than `0`. In that case, the `topP` value
for that sequence also influences the result. If the `topK` value for some
sequences is `0`, the top-P method will be used for those remaining sequences.
If both `topK` and `topP` are zero, greedy search is performed.
***Beam-search***

 * `beamWidth`, the width used for the
   [beam search](https://en.wikipedia.org/wiki/Beam_search) sampling algorithm.
   There is no explicit upper bound on the beam width, but increasing the beam
   width will likely increase the latency. Use `1` to disable beam-search,
 * `beamSearchDiversityRate`, a floating-point value that controls the
   diversity in beam-search. Its default value is `0.f`,
 * `lengthPenalty`, a floating-point value that controls how longer sequences
   are penalized in beam-search (the log-probability of a sequence is penalized
   by a factor that depends on `1.f / (length ^ lengthPenalty)`). The default
   value is `0.f`. The parameter `lengthPenalty` may be renamed to
   `beamSearchLengthPenalty` in a future release.

The `beamWidth` parameter is a scalar value. It means that, in this release of
TensorRT-LLM, it is not possible to specify a different width for each input
sequence. This limitation is likely to be removed in a future release.
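The following is a minimal sketch of a sampling configuration combining the
fields above. It assumes the vector fields are public members as described in
this section (check `samplingConfig.h` for the exact types), and the values are
arbitrary.

```cpp
#include <tensorrt_llm/runtime/samplingConfig.h>

#include <cstdint>
#include <vector>

using namespace tensorrt_llm::runtime;

// beamWidth is the only mandatory parameter; 1 disables beam-search.
SamplingConfig samplingConfig{1};

// Vector fields hold either one value per sequence or a single shared value.
samplingConfig.temperature = std::vector<float>{0.8f};
samplingConfig.topK        = std::vector<std::int32_t>{40};
samplingConfig.topP        = std::vector<float>{0.95f};
samplingConfig.minLength   = std::vector<std::int32_t>{1};
samplingConfig.randomSeed  = std::vector<std::uint64_t>{42};
```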
## Internal Components

The `GptSession` class encapsulates two main components. The
[`TllmRuntime`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/tensorrt_llm/runtime/tllmRuntime.h)
is in charge of executing the TensorRT engine. The
[`GptDecoder`](https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/include/tensorrt_llm/runtime/gptDecoder.h)
generates the tokens from the logits. The `TllmRuntime` class is an internal
component and users are not expected to use that class directly. The
`GptDecoder` can be used directly to implement a custom generation loop and for
use cases that cannot be satisfied by the implementation in `GptSession`.
## In-flight Batching Support

In this release, in-flight batching is supported using a separate decoder per
request. The biggest difference compared to using a single decoder is in how
the token generation from logits is managed. A batch is split into `batchSize`
individual requests and kernels are issued using separate CUDA streams. This
behavior may be revisited in a future release to maintain the structure of the
batch and improve efficiency.
## Known Issues and Future Changes

 * In the current release of TensorRT-LLM, the C++ and Python runtimes are two
   separate software components, and the C++ runtime is being more actively
   developed (with features like in-flight batching). An objective, for a
   future release, could be to rebuild the Python runtime on top of the C++
   one.