TensorRT-LLMs/_modules/tensorrt_llm/functional.html
2023-10-19 12:25:48 +00:00

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>tensorrt_llm.functional &mdash; tensorrt_llm documentation</title>
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script src="../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js?v=b3ba4146"></script>
<script src="../../_static/doctools.js?v=888ff710"></script>
<script src="../../_static/sphinx_highlight.js?v=4825356b"></script>
<script src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home">
tensorrt_llm
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Contents:</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../architecture.html">TensorRT-LLM Architecture</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../gpt_runtime.html">C++ GPT Runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../batch_manager.html">The Batch Manager in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../gpt_attention.html">Multi-head, Multi-query and Group-query Attention</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../precision.html">Numerical Precision</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../performance.html">Performance of TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../installation.html">Build From Sources</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../2023-05-19-how-to-debug.html">How to debug</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../2023-05-17-how-to-add-a-new-model.html">How to add a new model</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../graph-rewriting.html">Graph Rewriting Module</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.layers.html">Layers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.functional.html">Functionals</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.models.html">Models</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../_cpp_gen/runtime.html">Runtime</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">tensorrt_llm</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home" aria-label="Home"></a></li>
<li class="breadcrumb-item"><a href="../index.html">Module code</a></li>
<li class="breadcrumb-item active">tensorrt_llm.functional</li>
<li class="wy-breadcrumbs-aside">
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<h1>Source code for tensorrt_llm.functional</h1><div class="highlight"><pre>
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="kn">import</span> <span class="nn">math</span>
<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">OrderedDict</span>
<span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">IntEnum</span>
<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">tensorrt</span> <span class="k">as</span> <span class="nn">trt</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">.</span> <span class="kn">import</span> <span class="n">graph_rewriting</span> <span class="k">as</span> <span class="n">gw</span>
<span class="kn">from</span> <span class="nn">._common</span> <span class="kn">import</span> <span class="n">default_net</span><span class="p">,</span> <span class="n">default_trtnet</span><span class="p">,</span> <span class="n">precision</span>
<span class="kn">from</span> <span class="nn">._utils</span> <span class="kn">import</span> <span class="p">(</span><span class="n">dim_resolve_negative</span><span class="p">,</span> <span class="n">dim_to_trt_axes</span><span class="p">,</span> <span class="n">fp16_array</span><span class="p">,</span>
<span class="n">fp32_array</span><span class="p">,</span> <span class="n">int32_array</span><span class="p">,</span> <span class="n">np_dtype_to_trt</span><span class="p">,</span> <span class="n">str_dtype_to_np</span><span class="p">,</span>
<span class="n">str_dtype_to_trt</span><span class="p">,</span> <span class="n">torch_to_numpy</span><span class="p">,</span> <span class="n">trt_dtype_to_torch</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">.plugin</span> <span class="kn">import</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span>
<span class="kn">from</span> <span class="nn">.quantization</span> <span class="kn">import</span> <span class="n">QuantMode</span>
<div class="viewcode-block" id="DimRange"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.DimRange">[docs]</a><span class="k">class</span> <span class="nc">DimRange</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> One DimRange object stores the ranges of all the dimensions of one tensor in one optimization profile.</span>
<span class="sd"> For example, if a tensor has 2 dimensions, the data members are:</span>
<span class="sd"> self.min = [dim 0 min, dim 1 min]</span>
<span class="sd"> self.opt = [dim 0 opt, dim 1 opt]</span>
<span class="sd"> self.max = [dim 0 max, dim 1 max]</span>
<span class="sd"> For a static dimension, min==opt==max, so the shape parameter in the constructor can be an integer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]],</span>
<span class="n">names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Parameters:</span>
<span class="sd"> shape: a list of length N, where each element is an integer or a 3-element tuple/list of ints,</span>
<span class="sd"> where N is the number of dimensions for a tensor.</span>
<span class="sd"> When one element is an integer, it means that dimension is static.</span>
<span class="sd"> Otherwise, when one element is a tuple/list, it means the dimension is dynamic.</span>
<span class="sd"> The 3 elements in one tuple/list are ordered as (min, opt, max), and this function asserts</span>
<span class="sd"> 0 &lt;= min &lt;= opt &lt;= max.</span>
<span class="sd"> For example, consider a rank-3 tensor whose 1st dimension is static with value 16, whose 2nd dimension is</span>
<span class="sd"> dynamic with min/opt/max values 1/8/32, and whose 3rd dimension is static with value 8.</span>
<span class="sd"> The shape parameter could be:</span>
<span class="sd"> [16, (1, 8, 32), 8]</span>
<span class="sd"> It has the same semantics as</span>
<span class="sd"> [(16, 16, 16), (1, 8, 32), (8, 8, 8)]</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min</span> <span class="o">=</span> <span class="p">[]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">opt</span> <span class="o">=</span> <span class="p">[]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max</span> <span class="o">=</span> <span class="p">[]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span> <span class="o">=</span> <span class="n">names</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">names</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span>
<span class="n">shape</span>
<span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;Expecting the shape list and the name list to have the same length, got </span><span class="si">{</span><span class="n">shape</span><span class="si">=}</span><span class="s2">, </span><span class="si">{</span><span class="n">names</span><span class="si">=}</span><span class="s2">&quot;</span>
<span class="k">for</span> <span class="n">dim</span> <span class="ow">in</span> <span class="n">shape</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span> <span class="ow">and</span> <span class="mi">0</span> <span class="o">&lt;=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> \
<span class="sa">f</span><span class="s2">&quot;Each dimension must be specified as a 3-element tuple or list in the order (min, opt, max), got </span><span class="si">{</span><span class="n">dim</span><span class="si">=}</span><span class="s2">&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">AttributeError</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;Dimension should be [min, opt, max] (dynamic shape) or int (specific value). Got </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
<span class="k">def</span> <span class="fm">__eq__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">__value</span><span class="p">:</span> <span class="nb">object</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">__value</span><span class="p">,</span> <span class="n">DimRange</span><span class="p">)</span> <span class="ow">and</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">dimension_names</span> <span class="ow">and</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">min</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">min</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">opt</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">opt</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">max</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">max</span>
<span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
<span class="k">def</span> <span class="fm">__str__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="k">return</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span><span class="si">=}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="si">=}</span><span class="s2">, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="si">=}</span><span class="s2">, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="si">=}</span><span class="s2">&quot;</span>
<span class="k">def</span> <span class="fm">__hash__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">hash</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="p">))</span></div>
<div class="viewcode-block" id="Tensor"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor">[docs]</a><span class="k">class</span> <span class="nc">Tensor</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> The class to represent dense tensors.</span>
<span class="sd"> A dense tensor is named, has a shape and contains typed elements. Each</span>
<span class="sd"> dimension of a tensor can either be static or dynamic. Static dimensions</span>
<span class="sd"> are known at engine compilation by TensorRT. Dynamic dimensions can take</span>
<span class="sd"> values determined at runtime. The tensor can be located on the host (CPU)</span>
<span class="sd"> or the device (GPU).</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">dim_range</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">is_network_input</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">location</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">TensorLocation</span><span class="o">.</span><span class="n">DEVICE</span><span class="p">,</span>
<span class="n">network</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">trt_tensor</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Parameters:</span>
<span class="sd"> name : str</span>
<span class="sd"> The name of the tensor.</span>
<span class="sd"> dtype : tensorrt.DataType</span>
<span class="sd"> The type of the elements of the tensor. See the TensorRT</span>
<span class="sd"> documentation for list of supported data types.</span>
<span class="sd"> shape : tensorrt.Dims</span>
<span class="sd"> The dimensions of the tensor. In TensorRT-LLM, tensors can have</span>
<span class="sd"> static or dynamic dimensions (it is possible to mix static and</span>
<span class="sd"> dynamic dimensions). A static dimension is known when the</span>
<span class="sd"> TensorRT engine is built. A dynamic dimension can be set when</span>
<span class="sd"> the engine is executed. Use -1 for dynamic dimensions.</span>
<span class="sd"> dim_range : OrderedDict</span>
<span class="sd"> An ordered dictionary (the positions of the elements matter)</span>
<span class="sd"> that associates a name and a range of values to the dimensions.</span>
<span class="sd"> For a static dimension, the range must be limited to a single</span>
<span class="sd"> value. For a dynamic dimension, the range is defined by three</span>
<span class="sd"> values [min, opt, max] where min and max are, respectively, the</span>
<span class="sd"> smallest and largest possible values of that dimension. The</span>
<span class="sd"> opt value is used by TensorRT to optimize the engine for the</span>
<span class="sd"> most common case.</span>
<span class="sd"> Assume there are N optimization profiles; each item of the dim_range dict is ordered as:</span>
<span class="sd"> (dynamic dimension name : [profile 0 (min, opt, max), profile 1 (min, opt, max), ... profile N (min, opt, max)])</span>
<span class="sd"> or the following when the dimension is static (think of it as min==opt==max):</span>
<span class="sd"> (static dimension name : [profile 0 value, profile 1 value, ... profile N value])</span>
<span class="sd"> For a static dimension, the values for profiles 0-N must be the same (TODO: can it be simplified to a single value?).</span>
<span class="sd"> The length of each value list equals the number of optimization profiles.</span>
<span class="sd"> is_network_input : bool</span>
<span class="sd"> A boolean indicating if that tensor is an input of the network.</span>
<span class="sd"> Inputs must be provided by the user to run the engine.</span>
<span class="sd"> location : tensorrt.TensorLocation</span>
<span class="sd"> A flag to indicate where the tensor will be located. It can be</span>
<span class="sd"> on the host (CPU) or the device (GPU).</span>
<span class="sd"> network: Network</span>
<span class="sd"> A parent Network instance that helps to find the users of this tensor.</span>
<span class="sd"> trt_tensor: trt.ITensor</span>
<span class="sd"> Construct with the ITensor instance directly, and no shape profiles are required.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># Layout of self.profiles</span>
<span class="c1"># Opt profile 0: dim 0 (min, opt, max), dim 1 (min, opt, max) ... dim M</span>
<span class="c1"># Opt profile 1: dim 0 (min, opt, max), dim 1 (min, opt, max) ... dim M</span>
<span class="c1"># ...</span>
<span class="c1"># Opt profile N: dim 0 ... dim M</span>
<span class="c1"># So the conversion from the dim_range arg to self.profiles involves a layout transpose:</span>
<span class="c1"># the dim_range arg is {M dimensions x N profiles}, while the self.profiles layout is {N profiles x M dimensions}</span>
<span class="bp">self</span><span class="o">.</span><span class="n">profiles</span> <span class="o">=</span> <span class="p">[]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_tensor_wrapper</span> <span class="o">=</span> <span class="kc">False</span> <span class="c1"># specially for the graph rewriter</span>
<span class="c1"># work as a wrapper for a trt.ITensor, this is used specially in the graph rewriter</span>
<span class="k">if</span> <span class="n">trt_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_tensor_wrapper</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">assert</span> <span class="n">network</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="o">=</span> <span class="n">trt_tensor</span>
<span class="bp">self</span><span class="o">.</span><span class="n">network</span> <span class="o">=</span> <span class="n">network</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="n">is_network_input</span><span class="p">,</span> <span class="s2">&quot;is_network_input should be False when trt_tensor is not None&quot;</span>
<span class="k">return</span>
<span class="c1"># defining an input placeholder for the network</span>
<span class="bp">self</span><span class="o">.</span><span class="n">network</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span>
<span class="k">if</span> <span class="n">is_network_input</span><span class="p">:</span>
<span class="k">if</span> <span class="n">dim_range</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim_range</span><span class="p">,</span> <span class="n">OrderedDict</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span>
<span class="n">dim_range</span>
<span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;Each input tensor shall have at least one dimension, tensor &#39;</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">&#39; found </span><span class="si">{</span><span class="n">dim_range</span><span class="si">=}</span><span class="s2">&quot;</span>
<span class="n">found_profiles</span> <span class="o">=</span> <span class="p">[</span>
<span class="nb">len</span><span class="p">(</span><span class="n">ranges</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">ranges</span> <span class="ow">in</span> <span class="n">dim_range</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="p">]</span>
<span class="k">assert</span> <span class="nb">all</span><span class="p">(</span>
<span class="p">[</span><span class="n">x</span> <span class="o">==</span> <span class="n">found_profiles</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">found_profiles</span><span class="p">]</span>
<span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;Expecting all the dimensions in the dim_range has same number of profiles, tensor &#39;</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">&#39; got </span><span class="si">{</span><span class="n">dim_range</span><span class="si">=}</span><span class="s2">&quot;</span>
<span class="n">num_opt_profile</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">dim_range</span><span class="o">.</span><span class="n">items</span><span class="p">())[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">])</span>
<span class="k">assert</span> <span class="n">num_opt_profile</span> <span class="o">&gt;=</span> <span class="mi">1</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_opt_profile</span><span class="p">):</span>
<span class="n">range_shape</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">dimension_names</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">dim</span><span class="p">,</span> <span class="n">ranges</span> <span class="ow">in</span> <span class="n">dim_range</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">ranges</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">))</span>
<span class="n">range_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">ranges</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
<span class="n">dimension_names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">profiles</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">DimRange</span><span class="p">(</span><span class="n">range_shape</span><span class="p">,</span> <span class="n">dimension_names</span><span class="p">))</span>
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_add_input</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">dim_range</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
<span class="bp">self</span><span class="o">.</span><span class="n">location</span> <span class="o">=</span> <span class="n">location</span>
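The layout transpose described in the comments above, from dim_range as {M dimensions x N profiles} to profiles as {N profiles x M dimensions}, can be sketched as follows (a standalone illustration with made-up dimension names, not the library API):

```python
from collections import OrderedDict

# Hypothetical sketch of the dim_range-to-profiles transpose: each dict key is a
# dimension name mapped to one (min, opt, max) entry per optimization profile.
def transpose_dim_range(dim_range):
    num_profiles = len(next(iter(dim_range.values())))
    # Build one profile at a time, picking the i-th entry of every dimension.
    return [
        [(name, ranges[i]) for name, ranges in dim_range.items()]
        for i in range(num_profiles)
    ]

# Two profiles for a 2-D tensor (dimension names are made up for illustration):
dim_range = OrderedDict([
    ("batch_size", [(1, 8, 64), (1, 16, 128)]),
    ("seq_len", [(1, 128, 1024), (1, 256, 2048)]),
])
print(transpose_dim_range(dim_range))
```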
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> The name of the tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span>
<span class="nd">@name</span><span class="o">.</span><span class="n">setter</span>
<span class="k">def</span> <span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Set the name of the tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">name</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> The type of the elements in the tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span>
<span class="nd">@dtype</span><span class="o">.</span><span class="n">setter</span>
<span class="k">def</span> <span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Set the type of the elements in the tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> The shape of the tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
<span class="nd">@shape</span><span class="o">.</span><span class="n">setter</span>
<span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Set the shape of the tensor. See __init__.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">shape</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">location</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> The physical location of the tensor (on the host or the device).</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">location</span>
<span class="nd">@location</span><span class="o">.</span><span class="n">setter</span>
<span class="k">def</span> <span class="nf">location</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">location</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Set the physical location of the tensor (on the host or the device). See __init__.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">location</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">location</span> <span class="o">=</span> <span class="n">location</span>
<div class="viewcode-block" id="Tensor.mark_output"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.mark_output">[docs]</a> <span class="k">def</span> <span class="nf">mark_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Mark a tensor as a network output.</span>
<span class="sd"> When a tensor is marked as an output, its content can be obtained after</span>
<span class="sd"> the execution of the TensorRT engine. The user is responsible for</span>
<span class="sd"> allocating buffers to store the output tensors when preparing the</span>
<span class="sd"> execution of the TensorRT engine.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_mark_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">dtype</span><span class="p">)</span></div>
<span class="k">def</span> <span class="fm">__add__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.add.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span> <span class="fm">__radd__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.add.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">add</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
<span class="k">def</span> <span class="fm">__sub__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.sub.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">sub</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span> <span class="fm">__rsub__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.sub.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">sub</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
<span class="k">def</span> <span class="fm">__mul__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.mul.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">mul</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span> <span class="fm">__rmul__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.mul.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">mul</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
<span class="k">def</span> <span class="fm">__truediv__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.div.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span> <span class="fm">__lt__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.lt.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">lt</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span> <span class="fm">__gt__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.gt.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">gt</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span> <span class="fm">__eq__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.eq.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_tensor_wrapper</span><span class="p">:</span>
<span class="c1"># for graph rewriter</span>
<span class="k">return</span> <span class="nb">hash</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">==</span> <span class="nb">hash</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># for creating the network</span>
<span class="k">return</span> <span class="n">eq</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span> <span class="fm">__ge__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Maps to functional.gt or functional.eq.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">op_or</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="fm">__gt__</span><span class="p">(</span><span class="n">b</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="fm">__eq__</span><span class="p">(</span><span class="n">b</span><span class="p">))</span>
<span class="k">def</span> <span class="fm">__le__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Maps to functional.lt or functional.eq.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">op_or</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="fm">__lt__</span><span class="p">(</span><span class="n">b</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="fm">__eq__</span><span class="p">(</span><span class="n">b</span><span class="p">))</span>
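The comparison operators above build `&gt;=` and `&lt;=` by combining the elementwise `gt`/`lt` and `eq` results with a logical OR, exactly as `op_or(self.__gt__(b), self.__eq__(b))` does. A minimal pure-Python sketch of that composition (the list-based `gt`, `eq`, `op_or`, and `ge` helpers here are stand-ins, not the real TensorRT-LLM operations):

```python
def gt(a, b):
    # Elementwise greater-than over two sequences.
    return [x > y for x, y in zip(a, b)]

def eq(a, b):
    # Elementwise equality over two sequences.
    return [x == y for x, y in zip(a, b)]

def op_or(a, b):
    # Elementwise logical OR over two boolean sequences.
    return [x or y for x, y in zip(a, b)]

def ge(a, b):
    # '>=' is composed as (a > b) OR (a == b), mirroring __ge__ above.
    return op_or(gt(a, b), eq(a, b))

print(ge([1, 2, 3], [2, 2, 2]))  # [False, True, True]
```

`__le__` follows the same pattern with `lt` in place of `gt`.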
<div class="viewcode-block" id="Tensor.view"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.view">[docs]</a> <span class="k">def</span> <span class="nf">view</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.view.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.permute"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.permute">[docs]</a> <span class="k">def</span> <span class="nf">permute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.permute.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">permute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.transpose"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.transpose">[docs]</a> <span class="k">def</span> <span class="nf">transpose</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim0</span><span class="p">,</span> <span class="n">dim1</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.transpose.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">transpose</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim0</span><span class="p">,</span> <span class="n">dim1</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.mean"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.mean">[docs]</a> <span class="k">def</span> <span class="nf">mean</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.mean.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">mean</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.max"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.max">[docs]</a> <span class="k">def</span> <span class="nf">max</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.max.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="nb">max</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.abs"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.abs">[docs]</a> <span class="k">def</span> <span class="nf">abs</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.abs.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="nb">abs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.sqrt"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.sqrt">[docs]</a> <span class="k">def</span> <span class="nf">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.sqrt.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.cast"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.cast">[docs]</a> <span class="k">def</span> <span class="nf">cast</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.cast.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">cast</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.size"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.size">[docs]</a> <span class="k">def</span> <span class="nf">size</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Returns the shape of the tensor if the dim parameter is None.</span>
<span class="sd"> Otherwise, returns the size of the dimension indicated by dim. The</span>
<span class="sd"> behavior is undefined if dim is negative or exceeds the rank of the</span>
<span class="sd"> tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span></div>
<div class="viewcode-block" id="Tensor.rank"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.rank">[docs]</a> <span class="k">def</span> <span class="nf">rank</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Returns the rank (i.e. the number of dimensions) of the tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.ndim"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.ndim">[docs]</a> <span class="k">def</span> <span class="nf">ndim</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Returns the rank (i.e. the number of dimensions) of the tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span></div>
<div class="viewcode-block" id="Tensor.split"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.split">[docs]</a> <span class="k">def</span> <span class="nf">split</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">split_size_or_sections</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.split.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">split_size_or_sections</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.is_dynamic"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.is_dynamic">[docs]</a> <span class="k">def</span> <span class="nf">is_dynamic</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> If the argument &#39;dim&#39; is None, that function returns a boolean that</span>
<span class="sd"> indicates if the tensor contains a dynamic dimension (True) or not</span>
<span class="sd"> (False). In that case, the first dimension is excluded (as it usually</span>
<span class="sd"> corresponds to the batch size). If the argument is an integer, that</span>
<span class="sd"> function returns a boolean that indicates if the dimension &#39;dim&#39; is</span>
<span class="sd"> dynamic (True) or not (False).</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">s</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">True</span>
<span class="k">return</span> <span class="kc">False</span></div>
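The dynamic-dimension check in `is_dynamic` can be illustrated with a small pure-Python sketch (no TensorRT required; as with `trt.ITensor.shape`, a dimension is dynamic when its extent is -1, and the first dimension is skipped when `dim` is None):

```python
def is_dynamic(shape, dim=None):
    # A single dimension is dynamic when its extent is -1.
    if dim is not None:
        return shape[dim] == -1
    # With dim=None, the first (batch) dimension is excluded from the scan.
    return any(s == -1 for i, s in enumerate(shape) if i != 0)

print(is_dynamic([-1, 128, 64]))     # False: the dynamic batch dim is excluded
print(is_dynamic([-1, -1, 64]))      # True: a non-batch dim is dynamic
print(is_dynamic([-1, 128, 64], 0))  # True: explicit dim bypasses the exclusion
```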
<span class="c1"># graph writer related functions</span>
<div class="viewcode-block" id="Tensor.get_parent"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.get_parent">[docs]</a> <span class="k">def</span> <span class="nf">get_parent</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39; Get the layer that produces this tensor. &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">get_tensor_parent</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.get_users"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.get_users">[docs]</a> <span class="k">def</span> <span class="nf">get_users</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39; Get the layers that use this tensor as an input. &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">get_tensor_users</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.replace_all_uses_with"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.replace_all_uses_with">[docs]</a> <span class="k">def</span> <span class="nf">replace_all_uses_with</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_tensor</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Replace all uses of this tensor as an input to consumer layers.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">is_graph_altered</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">users</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_users</span><span class="p">()</span>
<span class="k">for</span> <span class="n">user</span> <span class="ow">in</span> <span class="n">users</span><span class="p">:</span>
<span class="n">inputs_changed</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">user</span><span class="o">.</span><span class="n">num_inputs</span><span class="p">):</span>
<span class="k">if</span> <span class="n">user</span><span class="o">.</span><span class="n">get_inputs</span><span class="p">(</span><span class="n">i</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="ow">is</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">:</span>
<span class="n">inputs_changed</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="n">user</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">inputs_changed</span> <span class="o">&gt;=</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;Tensor not found in layer inputs&quot;</span>
<span class="c1"># update the FLayerMetadata as well</span>
<span class="n">flayer</span> <span class="o">=</span> <span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">user</span><span class="o">.</span><span class="n">name</span><span class="p">)</span>
<span class="n">flayer</span> <span class="ow">and</span> <span class="n">flayer</span><span class="o">.</span><span class="n">replace_input_with</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_tensor</span><span class="p">)</span></div>
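The rewiring loop in `replace_all_uses_with` scans every consumer layer's inputs and swaps the old tensor for the new one, asserting that at least one input actually changed. A minimal sketch of that pattern under simplified assumptions (the `Layer` class and `replace_all_uses` helper are hypothetical stand-ins for the real network/layer API):

```python
class Layer:
    # Hypothetical consumer layer holding a mutable list of input tensors.
    def __init__(self, inputs):
        self.inputs = list(inputs)

def replace_all_uses(old, new, users):
    # Rewire every consumer that reads `old` to read `new` instead.
    for user in users:
        changed = 0
        for i, inp in enumerate(user.inputs):
            if inp is old:
                user.inputs[i] = new
                changed += 1
        # Mirrors the assertion in replace_all_uses_with above.
        assert changed >= 1, "Tensor not found in layer inputs"

old, new = object(), object()
l1, l2 = Layer([old, "x"]), Layer([old])
replace_all_uses(old, new, [l1, l2])
print(l1.inputs[0] is new, l2.inputs[0] is new)  # True True
```

Identity comparison (`is`) matters here: two wrappers may be distinct objects yet refer to the same underlying tensor, which is why the real code compares `trt_tensor` identities.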
<div class="viewcode-block" id="Tensor.is_trt_wrapper"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.is_trt_wrapper">[docs]</a> <span class="k">def</span> <span class="nf">is_trt_wrapper</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Check if there is a trt.ITensor member inside, which is required for</span>
<span class="sd"> graph rewriter. In order to differentiate usages, it may be necessary</span>
<span class="sd"> to have an inheritance hierarchy.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s1">&#39;trt_tensor&#39;</span><span class="p">):</span>
<span class="k">return</span> <span class="kc">True</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">False</span></div>
<span class="k">def</span> <span class="fm">__hash__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_trt_wrapper</span><span class="p">():</span>
<span class="k">return</span> <span class="nb">id</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">id</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span></div>
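In graph-rewriter mode, `__eq__` and `__hash__` above reduce equality to the identity of the wrapped `trt_tensor`, so two wrapper objects around the same underlying tensor are interchangeable in sets and dict keys. A small sketch of that identity-based hashing contract (the `Wrapper` class here is a hypothetical stand-in, not the real `Tensor`):

```python
class Wrapper:
    # Stand-in for a tensor wrapper: equality/hashing follow the identity
    # of the wrapped object, mirroring __hash__/__eq__ above.
    def __init__(self, inner):
        self.inner = inner

    def __hash__(self):
        return id(self.inner)

    def __eq__(self, other):
        return hash(self) == hash(other)

obj = object()
a, b = Wrapper(obj), Wrapper(obj)
print(a == b, len({a, b}))  # True 1: both wrappers collapse to one set entry
```

Keeping `__eq__` and `__hash__` consistent this way is what lets the rewriter deduplicate wrappers; the alternative `eq(self, b)` path builds an elementwise comparison layer in the network instead.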
<span class="k">def</span> <span class="nf">_create_tensor</span><span class="p">(</span><span class="n">trt_tensor</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ITensor</span><span class="p">,</span>
<span class="n">producer</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ILayer</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> A helper function to create a TensorRT-LLM Tensor object that encapsulates</span>
<span class="sd"> the connection between the TensorRT tensor (trt.ITensor) and the layer</span>
<span class="sd"> (trt.ILayer) that produces it.</span>
<span class="sd"> That function is expected to be used as:</span>
<span class="sd"> # Insert a new layer in the network using the TensorRT API:</span>
<span class="sd"> layer = default_trtnet().add_&lt;some_layer&gt;(...)</span>
<span class="sd"> # Extract the first output of that layer and connect it to the layer.</span>
<span class="sd"> return _create_tensor(layer.get_output(0), layer)</span>
<span class="sd"> That function also sets the precision of the layer/producer to the default</span>
<span class="sd"> precision of the network.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> trt_tensor : trt.ITensor</span>
<span class="sd"> The TensorRT tensor to connect to its producer (the layer).</span>
<span class="sd"> producer : trt.ILayer = None</span>
<span class="sd"> The producer.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The TensorRT-LLM tensor (functional.Tensor) that encapsulates the</span>
<span class="sd"> TensorRT tensor and the layer that produces it. The former is</span>
<span class="sd"> accessible through the attribute &#39;trt_tensor&#39; and the latter using the</span>
<span class="sd"> attribute &#39;producer&#39;.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">trt_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="n">is_network_input</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="o">=</span> <span class="n">trt_tensor</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">producer</span> <span class="o">=</span> <span class="n">producer</span>
<span class="c1"># Set the layer name since this is the only</span>
<span class="c1"># centralized location to pass the name from</span>
<span class="c1"># module space to the TRT IR</span>
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_set_layer_name</span><span class="p">(</span><span class="n">producer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="k">if</span> <span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span>
<span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">GATHER</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONCATENATION</span>
<span class="p">]:</span>
<span class="n">producer</span><span class="o">.</span><span class="n">precision</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">dtype</span>
<span class="k">assert</span> <span class="n">tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">cur_flayer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">cur_flayer</span><span class="o">.</span><span class="n">layer_name</span> <span class="o">=</span> <span class="n">producer</span><span class="o">.</span><span class="n">name</span>
<span class="k">return</span> <span class="n">tensor</span>
<div class="viewcode-block" id="RotaryScalingType"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RotaryScalingType">[docs]</a><span class="k">class</span> <span class="nc">RotaryScalingType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">none</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">linear</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">dynamic</span> <span class="o">=</span> <span class="mi">2</span></div>
<div class="viewcode-block" id="PositionEmbeddingType"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType">[docs]</a><span class="k">class</span> <span class="nc">PositionEmbeddingType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">learned_absolute</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">rope_gptj</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">rope_gpt_neox</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">alibi</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">alibi_with_scale</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">relative</span> <span class="o">=</span> <span class="mi">5</span>
<div class="viewcode-block" id="PositionEmbeddingType.is_rope"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_rope">[docs]</a> <span class="k">def</span> <span class="nf">is_rope</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">rope_gptj</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">]</span></div>
<div class="viewcode-block" id="PositionEmbeddingType.is_alibi"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_alibi">[docs]</a> <span class="k">def</span> <span class="nf">is_alibi</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">alibi</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">alibi_with_scale</span><span class="p">]</span></div>
<div class="viewcode-block" id="PositionEmbeddingType.choices"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.choices">[docs]</a> <span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">choices</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span>
<span class="k">return</span> <span class="p">[</span><span class="n">embedding</span><span class="o">.</span><span class="n">name</span> <span class="k">for</span> <span class="n">embedding</span> <span class="ow">in</span> <span class="n">PositionEmbeddingType</span><span class="p">]</span></div></div>
<div class="viewcode-block" id="AttentionMaskType"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AttentionMaskType">[docs]</a><span class="k">class</span> <span class="nc">AttentionMaskType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">padding</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">causal</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">bidirectional</span> <span class="o">=</span> <span class="mi">2</span></div>
<div class="viewcode-block" id="LayerNormType"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.LayerNormType">[docs]</a><span class="k">class</span> <span class="nc">LayerNormType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">LayerNorm</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">RmsNorm</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">GroupNorm</span> <span class="o">=</span> <span class="mi">2</span></div>
<div class="viewcode-block" id="LayerNormPositionType"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.LayerNormPositionType">[docs]</a><span class="k">class</span> <span class="nc">LayerNormPositionType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">pre_layernorm</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">post_layernorm</span> <span class="o">=</span> <span class="mi">1</span></div>
<div class="viewcode-block" id="activation"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.activation">[docs]</a><span class="k">def</span> <span class="nf">activation</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">act_type</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an activation function.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the activation function is applied.</span>
<span class="sd"> act_type : trt.ActivationType</span>
<span class="sd"> The type of the activation (RELU, TANH, SIGMOID, ...).</span>
<span class="sd"> The following closures are defined in functional.*:</span>
<span class="sd"> relu for op=trt.ActivationType.RELU</span>
<span class="sd"> tanh for op=trt.ActivationType.TANH</span>
<span class="sd"> sigmoid for op=trt.ActivationType.SIGMOID</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the activation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">act_type</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="clip"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.clip">[docs]</a><span class="k">def</span> <span class="nf">clip</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a CLIP operation that sets the range to [alpha, beta].</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the activation function is applied.</span>
<span class="sd"> alpha : float</span>
<span class="sd"> The lower bound of the CLIP function.</span>
<span class="sd"> beta : float</span>
<span class="sd"> The upper bound of the CLIP function.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the activation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">CLIP</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span>
<span class="n">layer</span><span class="o">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">beta</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<span class="n">relu</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">RELU</span><span class="p">)</span>
<span class="n">tanh</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">TANH</span><span class="p">)</span>
<span class="n">sigmoid</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">SIGMOID</span><span class="p">)</span>
<div class="viewcode-block" id="silu"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.silu">[docs]</a><span class="k">def</span> <span class="nf">silu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a SiLU (`x * sigmoid(x)`) operation.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the activation function is applied.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the activation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="nb">input</span> <span class="o">*</span> <span class="n">sigmoid</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span></div>
<div class="viewcode-block" id="swiglu"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.swiglu">[docs]</a><span class="k">def</span> <span class="nf">swiglu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a SwiGLU (`x * SiLU(gate)`) operation.</span>
<span class="sd"> That function takes a tensor, splits it into two halves along the last</span>
<span class="sd"> dimension, applies SiLU to the second half and multiply the results. The</span>
<span class="sd"> behaviour is undefined if the last dimension is not even.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the activation function is applied.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the activation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">x</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="n">chunk</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">silu</span><span class="p">(</span><span class="n">gate</span><span class="p">)</span> <span class="o">*</span> <span class="n">x</span></div>
<div class="viewcode-block" id="squared_relu"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.squared_relu">[docs]</a><span class="k">def</span> <span class="nf">squared_relu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a Squared ReLU operation.</span>
<span class="sd"> This function applies ReLU and squares the output.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the activation function is applied.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the activation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="nb">pow</span><span class="p">(</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="mf">2.0</span><span class="p">)</span></div>
<div class="viewcode-block" id="cast"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.cast">[docs]</a><span class="k">def</span> <span class="nf">cast</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a cast operation.</span>
<span class="sd"> For an input tensor of type INT8, this function sets the dynamic range of</span>
<span class="sd"> the input to [-127, 127] for automatic dequantization. For a cast into</span>
<span class="sd"> INT8, that function sets the dynamic range of the output to [-127, 127] for</span>
<span class="sd"> automatic quantization.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the cast is applied.</span>
<span class="sd"> dtype : str or trt.DataType</span>
<span class="sd"> The data type of the output tensor after the cast. When &#39;dtype&#39; is</span>
<span class="sd"> provided as a string, it must be a name amongst the valid names.</span>
<span class="sd"> See _str_to_trt_dtype_dict in _utils.py for a list of supported</span>
<span class="sd"> types and type names.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the inserted layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">cvt_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">):</span>
<span class="n">cvt_dtype</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">%s</span><span class="s2"> is not supported&quot;</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dtype</span><span class="p">))</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_cast</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">cvt_dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_output_type</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">cvt_dtype</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="s1">&#39;int8&#39;</span><span class="p">):</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="k">if</span> <span class="n">cvt_dtype</span> <span class="o">==</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="s1">&#39;int8&#39;</span><span class="p">):</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="flip"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.flip">[docs]</a><span class="k">def</span> <span class="nf">flip</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Reverses the order of an n-D tensor along given axis in dims.</span>
<span class="sd"> That flip operation maps to a TensorRT ISliceLayer. For the dimensions</span>
<span class="sd"> listed in dims it copies the elements from the last one to the first one</span>
<span class="sd"> (from (N-1) down to 0 with a step of -1). For the dimensions not in &#39;dims&#39;,</span>
<span class="sd"> it copies the elements from the first one to the last one (from 0 to N-1</span>
<span class="sd"> with a step of 1).</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the cast is applied.</span>
<span class="sd"> dims : list or tuple</span>
<span class="sd"> The axes to flip. Negative indices are supported.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the inserted layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dims</span><span class="p">):</span>
<span class="k">assert</span> <span class="o">-</span><span class="n">ndim</span> <span class="o">&lt;=</span> <span class="n">value</span> <span class="o">&lt;</span> <span class="n">ndim</span>
<span class="k">if</span> <span class="o">-</span><span class="n">ndim</span> <span class="o">&lt;=</span> <span class="n">value</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">dims</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">+=</span> <span class="n">ndim</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">dims</span><span class="p">))</span>
<span class="n">start_values</span> <span class="o">=</span> <span class="p">[</span>
<span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dims</span> <span class="k">else</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">stride_values</span> <span class="o">=</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dims</span> <span class="k">else</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">start</span><span class="o">=</span><span class="n">start_values</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(),</span>
<span class="n">stride</span><span class="o">=</span><span class="n">stride_values</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="interpolate"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.interpolate">[docs]</a><span class="k">def</span> <span class="nf">interpolate</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">scale_factor</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;nearest&#39;</span><span class="p">,</span>
<span class="n">align_corners</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">recompute_scale_factor</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">antialias</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to resize a tensor, similar to torch.nn.functional.interpolate.</span>
<span class="sd"> That operation maps to a tensorrt.IResizeLayer. Exactly one of &#39;size&#39; or</span>
<span class="sd"> &#39;scale_factor&#39; must be provided to define the output shape. Only 3D, 4D</span>
<span class="sd"> and 5D input tensors with static shapes are supported.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to resize.</span>
<span class="sd"> size : Union[int, List[int]]</span>
<span class="sd"> The target size of each spatial dimension (the dimensions after the</span>
<span class="sd"> first two). A single integer is broadcast to all spatial dimensions.</span>
<span class="sd"> scale_factor : Union[float, List[float]]</span>
<span class="sd"> The multiplier(s) applied to the spatial dimensions; each scaled</span>
<span class="sd"> size is rounded down to an integer.</span>
<span class="sd"> mode : str</span>
<span class="sd"> The interpolation algorithm: &#39;nearest&#39;, &#39;linear&#39;, &#39;bilinear&#39;,</span>
<span class="sd"> &#39;bicubic&#39;, &#39;trilinear&#39; or &#39;nearest-exact&#39;.</span>
<span class="sd"> align_corners : bool</span>
<span class="sd"> For the linear modes, use the ALIGN_CORNERS coordinate</span>
<span class="sd"> transformation instead of HALF_PIXEL.</span>
<span class="sd"> recompute_scale_factor : bool</span>
<span class="sd"> Accepted for API compatibility; currently unused.</span>
<span class="sd"> antialias : bool</span>
<span class="sd"> Accepted for API compatibility; currently unused.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the inserted layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
<span class="n">input_ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">assert</span> <span class="mi">2</span> <span class="o">&lt;</span> <span class="n">input_ndim</span> <span class="o">&lt;</span> <span class="mi">6</span><span class="p">,</span> <span class="s2">&quot;Only 3D, 4D and 5D input Tensors supported&quot;</span>
<span class="k">assert</span> <span class="p">(</span><span class="n">size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">^</span> <span class="p">(</span>
<span class="n">scale_factor</span>
 <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">),</span> <span class="s2">&quot;Only one of size or scale_factor should be defined&quot;</span>
<span class="k">assert</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">(</span><span class="s1">&#39;nearest&#39;</span><span class="p">,</span> <span class="s1">&#39;linear&#39;</span><span class="p">,</span> <span class="s1">&#39;bilinear&#39;</span><span class="p">,</span> <span class="s1">&#39;bicubic&#39;</span><span class="p">,</span> <span class="s1">&#39;trilinear&#39;</span><span class="p">,</span>
<span class="s1">&#39;nearest-exact&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;trilinear&#39;</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">5</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;trilinear only supports 5D tensor&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;bilinear&quot;</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">4</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;bilinear only supports 4D tensor&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;linear&quot;</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">3</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;linear only supports 3D tensor&quot;</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_resize</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">input_shape</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="n">scale_factor</span><span class="p">:</span>
<span class="n">scale_len</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">,</span>
<span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="nb">int</span><span class="p">))</span> <span class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">scale_len</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">,</span> <span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="nb">int</span><span class="p">)):</span>
<span class="n">updated_scale</span> <span class="o">=</span> <span class="p">[</span><span class="n">scale_factor</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">updated_scale</span> <span class="o">=</span> <span class="n">scale_factor</span>
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[</span>
<span class="nb">int</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">updated_scale</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span>
<span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span> <span class="k">if</span> <span class="n">i</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="k">else</span> <span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">size_len</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">size</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">size_len</span> <span class="o">==</span> <span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span>
<span class="k">if</span> <span class="n">size_len</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">updated_size</span> <span class="o">=</span> <span class="p">[</span><span class="n">size</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">updated_size</span> <span class="o">=</span> <span class="n">size</span>
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">if</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">2</span> <span class="k">else</span> <span class="n">updated_size</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">2</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">layer</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">updated_shape</span>
<span class="k">if</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;nearest&#39;</span><span class="p">,</span> <span class="s1">&#39;nearest-exact&#39;</span><span class="p">]</span> <span class="ow">or</span> <span class="n">mode</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeMode</span><span class="o">.</span><span class="n">NEAREST</span>
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ASYMMETRIC</span>
<span class="k">elif</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;linear&#39;</span><span class="p">,</span> <span class="s1">&#39;bilinear&#39;</span><span class="p">,</span> <span class="s1">&#39;trilinear&#39;</span><span class="p">]:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeMode</span><span class="o">.</span><span class="n">LINEAR</span>
<span class="k">if</span> <span class="n">align_corners</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ALIGN_CORNERS</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
<span class="c1"># TODO, need to confirm the align_corners effect on bilinear mode.</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;bilinear&#39;</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
<span class="k">elif</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;bicubic&#39;</span><span class="p">]:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeMode</span><span class="o">.</span><span class="n">CUBIC</span>
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeMode</span><span class="o">.</span><span class="n">NEAREST</span>
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ASYMMETRIC</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
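The output-shape arithmetic of the `scale_factor` branch above can be sketched in pure Python (the helper name `interpolated_shape` is illustrative, not part of the API): the first two dimensions (batch and channel) pass through, and every later dimension is scaled and floored.

```python
import math

def interpolated_shape(input_shape, scale_factor):
    # Mirror of the scale_factor branch above: keep the first two dims,
    # scale the spatial dims and round down.
    ndim = len(input_shape)
    scales = ([scale_factor] * (ndim - 2)
              if isinstance(scale_factor, (float, int)) else scale_factor)
    return [
        input_shape[i] if i < 2 else int(math.floor(scales[i - 2] * input_shape[i]))
        for i in range(ndim)
    ]

print(interpolated_shape([1, 3, 8, 8], 0.5))   # NCHW input, halve H and W -> [1, 3, 4, 4]
print(interpolated_shape([2, 4, 10], [3.0]))   # 3D input, one scale per spatial dim -> [2, 4, 30]
```

A single float is broadcast to all spatial dimensions, matching the `isinstance(scale_factor, (float, int))` check in the function.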
<div class="viewcode-block" id="matmul"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.matmul">[docs]</a><span class="k">def</span> <span class="nf">matmul</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">mat2</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">transa</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">transb</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a matrix multiplication.</span>
<span class="sd"> That operation maps to a tensorrt.IMatrixMultiplyLayer layer. As explained</span>
<span class="sd"> in the TensorRT documentation, it computes the inner product between the</span>
<span class="sd"> two inputs after applying an optional transposition on the inputs.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The first tensor (often called A).</span>
<span class="sd"> mat2 : Tensor</span>
<span class="sd"> The second tensor (often called B).</span>
<span class="sd"> transa : bool</span>
<span class="sd"> Is the first input transposed? Set to &#39;True&#39; if you want the first</span>
<span class="sd"> input to be transposed, &#39;False&#39; otherwise.</span>
<span class="sd"> transb : bool</span>
<span class="sd"> Is the second input transposed? Set to &#39;True&#39; if you want the</span>
<span class="sd"> second input to be transposed, &#39;False&#39; otherwise.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the inserted layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="nb">input</span><span class="p">,</span> <span class="n">mat2</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">mat2</span><span class="p">)</span>
<span class="n">op0</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">TRANSPOSE</span> <span class="k">if</span> <span class="n">transa</span> \
<span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">NONE</span>
<span class="n">op1</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">TRANSPOSE</span> <span class="k">if</span> <span class="n">transb</span> \
<span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">NONE</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_matrix_multiply</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op0</span><span class="p">,</span>
<span class="n">mat2</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
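The `transa`/`transb` flags follow the usual BLAS convention: each operand is optionally transposed before the product. A tiny pure-Python reference of those semantics (illustrative only, not the TensorRT call):

```python
def matmul_ref(a, b, transa=False, transb=False):
    # Reference semantics of the layer above for 2D lists:
    # optionally transpose either operand, then take a plain matrix product.
    t = lambda m: [list(row) for row in zip(*m)]
    a = t(a) if transa else a
    b = t(b) if transb else b
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
print(matmul_ref(a, b))               # A @ B   -> [[19, 22], [43, 50]]
print(matmul_ref(a, b, transb=True))  # A @ B^T -> [[17, 23], [39, 53]]
```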
<div class="viewcode-block" id="constant"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.constant">[docs]</a><span class="k">def</span> <span class="nf">constant</span><span class="p">(</span><span class="n">ndarray</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a constant layer.</span>
<span class="sd"> TensorRT graphs encapsulate constant values in the form of constant layers</span>
<span class="sd"> (tensorrt.IConstantLayer). That function creates such a layer from a Numpy</span>
<span class="sd"> array of values. After compilation of the network by TensorRT, those</span>
<span class="sd"> weights are stored in the serialized TensorRT engine.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> ndarray : numpy.ndarray</span>
<span class="sd"> The array of values (weights) encapsulated by this constant layer.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the inserted layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">weights</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">(</span><span class="n">np_dtype_to_trt</span><span class="p">(</span><span class="n">ndarray</span><span class="o">.</span><span class="n">dtype</span><span class="p">),</span> <span class="n">ndarray</span><span class="o">.</span><span class="n">ctypes</span><span class="o">.</span><span class="n">data</span><span class="p">,</span>
<span class="n">ndarray</span><span class="o">.</span><span class="n">size</span><span class="p">)</span>
<span class="c1"># Prevent underlying numpy array from going out of scope</span>
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">register_ndarray</span><span class="p">(</span><span class="n">ndarray</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_constant</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">Dims</span><span class="p">(</span><span class="n">ndarray</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="n">weights</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_strongly_typed</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_output_type</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">np_dtype_to_trt</span><span class="p">(</span><span class="n">ndarray</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<span class="c1"># TODO: TensorRT uses sizes of the output dimensions.</span>
<span class="c1"># DL frameworks usually use ends. Will change it to ends.</span>
<div class="viewcode-block" id="slice"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.slice">[docs]</a><span class="k">def</span> <span class="nf">slice</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">starts</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">sizes</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to extract a slice from a tensor.</span>
<span class="sd"> As described in the TensorRT documentation of the ISliceLayer, the slice</span>
<span class="sd"> layer has two variants: Static and dynamic.</span>
<span class="sd"> For static slicing, this function takes the starts and sizes values in the</span>
<span class="sd"> different dimensions to slice at layer creation time via a sequence of</span>
<span class="sd"> integers. For dynamic slicing, it accepts starts and sizes as</span>
<span class="sd"> tensorrt.ITensor objects.</span>
<span class="sd"> The slice layer selects for each dimension a start location from within the</span>
<span class="sd"> input tensor, and copies elements to the output tensor using a stride of 1</span>
<span class="sd"> across the input tensor. Start and size tensors must be 1-D int32 shape</span>
<span class="sd"> tensors if not specified as a sequence of integers.</span>
<span class="sd"> As an example, on input = [[0, 2, 4], [1, 3, 5]], the call to</span>
<span class="sd"> slice(input, start=[1, 0], size=[1, 2])</span>
<span class="sd"> will produce the tensor [[1, 3]] as output. The slice operator when</span>
<span class="sd"> executed by TensorRT will copy one row (because size[0] == 1) starting from</span>
<span class="sd"> the 2nd row (because start[0] == 1) and two columns (size[1] == 2) starting</span>
<span class="sd"> from the 1st column (because start[1] == 0).</span>
<span class="sd"> In pseudo-code the behaviour of that operation can be described as follows</span>
<span class="sd"> for a 2D tensor (and easily be extended to more dimensions):</span>
<span class="sd"> output = Tensor(shape=sizes)</span>
<span class="sd"> for ii in range(sizes[0]):</span>
<span class="sd"> for jj in range(sizes[1]):</span>
<span class="sd"> output[ii][jj] = input[starts[0]+ii][starts[1]+jj]</span>
<span class="sd"> Note that it is common in deep-learning frameworks to use ranges</span>
<span class="sd"> [start:end] for similar operations. It can be emulated by setting the sizes</span>
<span class="sd"> argument such that in each dimension [start:start+size] == [start:end] i.e.</span>
<span class="sd"> size = end-start.</span>
<span class="sd"> TensorRT supports different slice modes but that function restricts that</span>
<span class="sd"> choice to `mode == tensorrt.SliceMode.STRICT_BOUNDS`.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the slicing is performed.</span>
<span class="sd"> starts : Union[Tensor, Sequence[int]]</span>
<span class="sd"> The starting points, in each dimension of the input tensor.</span>
<span class="sd"> sizes : Union[Tensor, Sequence[int]]</span>
<span class="sd"> The number of elements in each dimension of the sliced tensor (output).</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the slice layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">input_ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">trt_starts</span> <span class="o">=</span> <span class="n">starts</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">starts</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="n">trt_starts</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)]</span> <span class="c1"># unused dummy value</span>
<span class="n">trt_sizes</span> <span class="o">=</span> <span class="n">sizes</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="n">trt_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)]</span> <span class="c1"># unused dummy value</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">start</span><span class="o">=</span><span class="n">trt_starts</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="n">trt_sizes</span><span class="p">,</span>
<span class="n">stride</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)])</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">starts</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">starts</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">sizes</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
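The docstring's pseudo-code for the static 2D case can be run directly; a minimal pure-Python reference (the name `slice_ref` is illustrative, not the TensorRT call):

```python
def slice_ref(input, starts, sizes):
    # 2D reference of the docstring pseudo-code: copy a sizes[0] x sizes[1]
    # window starting at (starts[0], starts[1]), with a stride of 1.
    return [[input[starts[0] + ii][starts[1] + jj] for jj in range(sizes[1])]
            for ii in range(sizes[0])]

# The docstring example: one row starting at row 1, two columns from column 0.
print(slice_ref([[0, 2, 4], [1, 3, 5]], starts=[1, 0], sizes=[1, 2]))  # [[1, 3]]
```

Emulating a `[start:end]` range, as noted above, is just `sizes = [end - start for each dim]`.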
<span class="c1"># TODO: support step.</span>
<div class="viewcode-block" id="arange"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.arange">[docs]</a><span class="k">def</span> <span class="nf">arange</span><span class="p">(</span><span class="n">start</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span> <span class="n">end</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to fill a 1D tensor.</span>
<span class="sd"> The tensor is filled with the values between start and end with a step of 1</span>
<span class="sd"> between the different elements. In pseudo-code, it corresponds to a tensor</span>
<span class="sd"> populated with the values:</span>
<span class="sd"> output = Tensor([dtype(ii) for ii in range(start, end, 1)])</span>
<span class="sd"> For example, a call to arange(3, 6, &#39;int32&#39;) will add an operation to the</span>
<span class="sd"> TensorRT graph that will produce [3, 4, 5] when executed. The call to</span>
<span class="sd"> arange(2, 5, &#39;float32&#39;) will add a layer to generate [2.0, 3.0, 4.0].</span>
<span class="sd"> This operation is implemented using a tensorrt.IFillLayer in</span>
<span class="sd"> trt.FillOperation.LINSPACE mode.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> start : Union[Tensor, int]</span>
<span class="sd"> The starting point of the range.</span>
<span class="sd"> end : Union[Tensor, int]</span>
<span class="sd"> The end point of the range.</span>
<span class="sd"> dtype : str</span>
<span class="sd"> The type of the elements. See _str_to_trt_dtype_dict in _utils.py</span>
<span class="sd"> for a list of supported types and type names.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the fill layer. It is a 1D tensor containing</span>
<span class="sd"> `end-start` elements of type `dtype`.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="n">start</span><span class="p">))</span>
<span class="n">end</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="n">end</span><span class="p">))</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">%s</span><span class="s2"> is not supported&quot;</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">start</span><span class="p">))</span>
<span class="n">step</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="mi">1</span><span class="p">]))</span>
<span class="n">num</span> <span class="o">=</span> <span class="n">end</span> <span class="o">-</span> <span class="n">start</span>
<span class="n">num</span> <span class="o">=</span> <span class="n">num</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="mi">1</span><span class="p">])</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_fill</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">trt</span><span class="o">.</span><span class="n">FillOperation</span><span class="o">.</span><span class="n">LINSPACE</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 1</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">start</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 0</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">step</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 1</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
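As a sketch of the semantics only (NumPy stands in for the TensorRT graph here; the real function builds an IFillLayer as shown above), the result is the half-open range [start, end) with `end - start` elements:

```python
import numpy as np

# NumPy analogue of the arange() semantics described above: a 1D tensor
# with end - start elements, starting at 'start' with a step of 1.
start, end = 2, 7
result = np.arange(start, end, dtype=np.int32)
print(result)           # [2 3 4 5 6]
print(result.shape[0])  # 5, i.e. end - start
```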
<div class="viewcode-block" id="expand"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand">[docs]</a><span class="k">def</span> <span class="nf">expand</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">expand_shape</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to expand a tensor.</span>
<span class="sd"> The operation expands the input tensor in the singleton dimensions to the</span>
<span class="sd"> size indicated by the corresponding dimension in the `expand_shape` tensor.</span>
<span class="sd"> In other words, given an input tensor with dimensions of size 1, those</span>
<span class="sd"> dimensions will be expanded to the size in `expand_shape`.</span>
<span class="sd"> For example, a tensor of shape [4, 3, 1, 3] will be expanded to a tensor of</span>
<span class="sd"> shape [4, 3, 2, 3] by the layer created using expand(input, [4, 3, 2, 3]).</span>
<span class="sd"> The expansion may either replicate the values or be mapped to a view with a</span>
<span class="sd"> stride of 0 in the expanded dimensions. For example, for a tensor [[3, 2]] of</span>
<span class="sd"> shape [1, 2],</span>
<span class="sd"> expand([[3, 2]], [2, 2])</span>
<span class="sd"> can be used to expand the input to [[3, 2], [3, 2]].</span>
<span class="sd"> This operation is implemented using a tensorrt.ISliceLayer. The current</span>
<span class="sd"> implementation does not verify that non-singleton dimensions are not</span>
<span class="sd"> shrunk. In other words, for an input of shape [4, 1, 2],</span>
<span class="sd"> expand(input, [3, 2, 2])</span>
<span class="sd"> will produce a tensor of shape [3, 2, 2]. That behaviour is subject to</span>
<span class="sd"> change in the future.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> expand_shape : Tensor</span>
<span class="sd"> The new shape of the expanded tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the expand layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span>
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">start</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)],</span>
<span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)],</span> <span class="c1"># unused dummy value</span>
<span class="n">stride</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span> <span class="c1"># unused dummy value</span>
<span class="p">)</span>
<span class="c1"># The stride is either:</span>
<span class="c1"># 0 for dimensions of size 1 (i.e. shape(input, i) - 1 == 1 - 1 == 0) or,</span>
<span class="c1"># 1 for dimensions of size &gt; 1 since minimum(value &gt;= 1, 1) == 1.</span>
<span class="n">stride_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span>
<span class="p">[</span><span class="n">minimum</span><span class="p">((</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)])</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">expand_shape</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride_tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
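The stride-0 view described above behaves like NumPy broadcasting; the following sketch (NumPy only, not the TensorRT-LLM API) illustrates how singleton dimensions expand without copying data:

```python
import numpy as np

# NumPy analogue of expand(): singleton dimensions are mapped to a
# stride-0 view of the requested size, so the rows share storage.
x = np.array([[3, 2]])          # shape [1, 2]
y = np.broadcast_to(x, (2, 2))  # shape [2, 2]
print(y.tolist())               # [[3, 2], [3, 2]]
print(y.strides[0])             # 0: the expanded dimension has stride 0
```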
<div class="viewcode-block" id="einsum"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.einsum">[docs]</a><span class="k">def</span> <span class="nf">einsum</span><span class="p">(</span><span class="n">einsum_eq</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">inputs</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an Einsum operation.</span>
<span class="sd"> That operation maps to tensorrt.IEinsumLayer. As explained in the TensorRT</span>
<span class="sd"> documentation, this layer implements a summation over the elements of the</span>
<span class="sd"> inputs along dimensions specified by the equation parameter, based on the</span>
<span class="sd"> Einstein summation convention. The layer can have one or more inputs of</span>
<span class="sd"> rank &gt;= 0. All the inputs must be of same data type. This layer supports</span>
<span class="sd"> all TensorRT data types except bool. There is one output tensor of the same</span>
<span class="sd"> type as the input tensors. The shape of output tensor is determined by the</span>
<span class="sd"> equation.</span>
<span class="sd"> The equation specifies ASCII lower-case letters for each dimension in the</span>
<span class="sd"> inputs in the same order as the dimensions, separated by comma for each</span>
<span class="sd"> input. The dimensions labeled with the same subscript must match or be</span>
<span class="sd"> broadcastable. Repeated subscript labels in one input take the diagonal.</span>
<span class="sd"> Repeating a label across multiple inputs means that those axes will be</span>
<span class="sd"> multiplied. Omitting a label from the output means values along those axes</span>
<span class="sd"> will be summed. In implicit mode, the indices which appear once in the</span>
<span class="sd"> expression will be part of the output in increasing alphabetical order. In</span>
<span class="sd"> explicit mode, the output can be controlled by specifying output subscript</span>
<span class="sd"> labels by adding an arrow (-&gt;) followed by subscripts for the output. For</span>
<span class="sd"> example, “ij,jk-&gt;ik” is equivalent to “ij,jk”. Ellipsis (‘…’) can be used</span>
<span class="sd"> in place of subscripts to broadcast the dimensions. See the TensorRT</span>
<span class="sd"> Developer Guide for more details on equation syntax.</span>
<span class="sd"> Many common operations can be expressed using the Einsum equation. For</span>
<span class="sd"> example:</span>
<span class="sd"> Matrix Transpose: ij-&gt;ji</span>
<span class="sd"> Sum: ij-&gt; Matrix-Matrix</span>
<span class="sd"> Multiplication: ik,kj-&gt;ij</span>
<span class="sd"> Dot Product: i,i-&gt;</span>
<span class="sd"> Matrix-Vector Multiplication: ik,k-&gt;i</span>
<span class="sd"> Batch Matrix Multiplication: ijk,ikl-&gt;ijl</span>
<span class="sd"> Batch Diagonal: …ii-&gt;…i</span>
<span class="sd"> Note that TensorRT does not support ellipsis or diagonal operations, so</span>
<span class="sd"> neither does TensorRT-LLM.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> einsum_eq : str</span>
<span class="sd"> The Einsum equation.</span>
<span class="sd"> inputs: Sequence[Tensor]</span>
<span class="sd"> The sequence of inputs consumed by the Einsum operation.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the Einsum operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_einsum</span><span class="p">([</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">],</span>
<span class="n">einsum_eq</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
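A few of the equations listed above can be checked with `np.einsum`, which follows the same Einstein-summation conventions (this is only an illustration of the equation syntax, not the TensorRT-LLM API):

```python
import numpy as np

# Verifying the example equations from the docstring with np.einsum.
a = np.arange(6).reshape(2, 3)
b = np.arange(12).reshape(3, 4)
v = np.arange(3)

print((np.einsum('ij->ji', a) == a.T).all())          # matrix transpose
print((np.einsum('ik,kj->ij', a, b) == a @ b).all())  # matrix multiplication
print(np.einsum('i,i->', v, v))                       # dot product: 5
```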
<div class="viewcode-block" id="permute"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.permute">[docs]</a><span class="k">def</span> <span class="nf">permute</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to permute the dimensions of a tensor.</span>
<span class="sd"> The dimensions of the input tensor are permuted according to the sequence</span>
<span class="sd"> of dimensions in &#39;dims&#39;. That operation maps to tensorrt.IShuffleLayer where</span>
<span class="sd"> the second transposition is described by the indices in &#39;dims&#39;.</span>
<span class="sd"> Given a tensor of rank N, the result of the permutation is a tensor of rank</span>
<span class="sd"> N in which the i-th output dimension corresponds to the dims[i]-th input dimension.</span>
<span class="sd"> For example, permute(input, [1, 0]) will transpose a 2D tensor by permuting</span>
<span class="sd"> the rows and columns.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to permute.</span>
<span class="sd"> dims : Sequence[int]</span>
<span class="sd"> The description of the permutation.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the permutation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">dims</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">dims</span><span class="p">),</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">second_transpose</span> <span class="o">=</span> <span class="n">dims</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
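The permutation follows the same convention as `np.transpose` (and `torch.permute`): output dimension i takes the size of input dimension dims[i]. A NumPy-only sketch:

```python
import numpy as np

# Output axis i corresponds to input axis dims[i]:
# for dims = (2, 0, 1), the output shape is (shape[2], shape[0], shape[1]).
x = np.zeros((2, 3, 4))
y = np.transpose(x, (2, 0, 1))
print(y.shape)  # (4, 2, 3)
```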
<div class="viewcode-block" id="transpose"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.transpose">[docs]</a><span class="k">def</span> <span class="nf">transpose</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim0</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim1</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to transpose two dimensions of a tensor.</span>
<span class="sd"> That operation produces a tensor in which the dimensions &#39;dim0&#39; and &#39;dim1&#39;</span>
<span class="sd"> are permuted. The other dimensions, if the rank of the tensor is greater</span>
<span class="sd"> than 2, remain untouched.</span>
<span class="sd"> That function is a helper built on the &#39;functional.permute&#39; function.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to transpose.</span>
<span class="sd"> dim0 : int</span>
<span class="sd"> The first dimension to transpose.</span>
<span class="sd"> dim1 : int</span>
<span class="sd"> The second dimension to transpose.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the permutation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">permutation</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()))</span>
<span class="n">permutation</span><span class="p">[</span><span class="n">dim0</span><span class="p">]</span> <span class="o">=</span> <span class="n">dim1</span>
<span class="n">permutation</span><span class="p">[</span><span class="n">dim1</span><span class="p">]</span> <span class="o">=</span> <span class="n">dim0</span>
<span class="k">return</span> <span class="n">permute</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">permutation</span><span class="p">)</span></div>
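The helper builds a full permutation in which only the two requested axes are swapped, equivalent to `np.swapaxes`. A NumPy-only sketch of that construction:

```python
import numpy as np

# transpose(input, dim0, dim1): swap two entries of the identity
# permutation, leaving all other dimensions in place.
x = np.zeros((2, 3, 4))
dim0, dim1 = 0, 2

perm = list(range(x.ndim))
perm[dim0], perm[dim1] = perm[dim1], perm[dim0]

print(np.transpose(x, perm).shape)       # (4, 3, 2)
print(np.swapaxes(x, dim0, dim1).shape)  # (4, 3, 2), identical result
```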
<div class="viewcode-block" id="view"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.view">[docs]</a><span class="k">def</span> <span class="nf">view</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">zero_is_placeholder</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to create a view of a tensor.</span>
<span class="sd"> That operation adds a tensorrt.IShuffleLayer to the network. If the &#39;shape&#39;</span>
<span class="sd"> parameter is a Tensor, that view is dynamic. Otherwise, it is a static</span>
<span class="sd"> view.</span>
<span class="sd"> Note that TensorRT limits the number of inferred dimensions to 1. It means</span>
<span class="sd"> that the shape sequence or tensor cannot contain more than one -1. This</span>
<span class="sd"> function enforces that constraint and will assert if it is not respected.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to reshape.</span>
<span class="sd"> shape : Union[Tensor, Sequence[int]]</span>
<span class="sd"> The shape of the new tensor.</span>
<span class="sd"> zero_is_placeholder : bool</span>
<span class="sd"> When that parameter is True, the 0s in &#39;shape&#39; are replaced by the</span>
<span class="sd"> sizes of the corresponding dimensions from the &#39;input&#39;. Otherwise,</span>
<span class="sd"> the dimensions corresponding to 0s have their size set to 0.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the view/shuffle layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># TensorRT permits at most one dimension to be specified as -1</span>
<span class="k">def</span> <span class="nf">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="nb">list</span><span class="p">):</span>
<span class="n">inferred_dim_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">list</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">inferred_dim_list</span><span class="p">)</span> <span class="o">&lt;=</span> <span class="mi">1</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">zero_is_placeholder</span> <span class="o">=</span> <span class="n">zero_is_placeholder</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="n">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="n">shape</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">shape</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
<span class="n">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">reshape_dims</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">%s</span><span class="s2"> is not supported&quot;</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">shape</span><span class="p">))</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
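The single-inferred-dimension rule matches NumPy's reshape, which also allows at most one -1; a NumPy-only sketch (`zero_is_placeholder` has no direct NumPy analogue):

```python
import numpy as np

# NumPy analogue of a static view(): at most one dimension may be -1,
# and its size is inferred from the total element count.
x = np.arange(24).reshape(2, 3, 4)
print(x.reshape(4, -1).shape)  # (4, 6): the -1 is inferred as 24 // 4
```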
<div class="viewcode-block" id="expand_dims"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_dims">[docs]</a><span class="k">def</span> <span class="nf">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to expand the tensor shape with singleton dimensions.</span>
<span class="sd"> That function adds a tensorrt.IShuffleLayer to the network. Given an &#39;input&#39;</span>
<span class="sd"> of rank N and a sequence of M dimensions, the output tensor produced by</span>
<span class="sd"> this operation (when executed by TensorRT) will have a rank of N+M. Singleton</span>
<span class="sd"> dimensions will be inserted at the different positions in &#39;dim&#39;.</span>
<span class="sd"> The pseudo-code for that operation is:</span>
<span class="sd"> new_shape, ii = [], 0</span>
<span class="sd"> for jj in range(input.rank() + len(dim)):</span>
<span class="sd"> new_shape.append(1 if jj in dim else input.shape[ii++])</span>
<span class="sd"> For example, for a tensor of shape [3, 4, 1, 5]</span>
<span class="sd"> expand_dims(input, [0, 2])</span>
<span class="sd"> will produce a tensor of shape [1, 3, 1, 4, 1, 5].</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to expand.</span>
<span class="sd"> dim : Union[int, Sequence[int]]</span>
<span class="sd"> The positions in the output tensor where to insert singleton</span>
<span class="sd"> dimensions.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the shuffle layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">dim</span> <span class="o">=</span> <span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="p">)</span>
<span class="n">out_ndim</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="o">+</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">input_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="n">out_shapes</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">j</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">out_ndim</span><span class="p">):</span>
<span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dim</span><span class="p">:</span>
<span class="n">out_shapes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">out_shapes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">gather</span><span class="p">(</span><span class="n">input_shape</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">j</span><span class="p">))</span>
<span class="n">j</span> <span class="o">=</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">out_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">out_shapes</span><span class="p">)</span>
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">out_shape</span><span class="p">)</span></div>
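The pseudo-code in the docstring can be written out as a small NumPy reference (an illustration of the shape computation only, not the TensorRT-LLM implementation):

```python
import numpy as np

def expand_dims_ref(x, dim):
    # Insert size-1 axes at the positions listed in 'dim'; all other
    # output axes take the input sizes in order.
    new_shape, ii = [], 0
    for jj in range(x.ndim + len(dim)):
        if jj in dim:
            new_shape.append(1)
        else:
            new_shape.append(x.shape[ii])
            ii += 1
    return x.reshape(new_shape)

x = np.zeros((3, 4, 1, 5))
print(expand_dims_ref(x, [0, 2]).shape)  # (1, 3, 1, 4, 1, 5)
```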
<div class="viewcode-block" id="unsqueeze"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.unsqueeze">[docs]</a><span class="k">def</span> <span class="nf">unsqueeze</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">axis</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to insert a singleton dimension to a tensor.</span>
<span class="sd"> That function creates an operation that inserts a singleton dimension</span>
<span class="sd"> (dimension of size 1) at position &#39;axis&#39; in the output tensor. It works with</span>
<span class="sd"> negative values for the &#39;axis&#39;.</span>
<span class="sd"> For example, for a tensor &#39;input&#39; of shape [4, 4]:</span>
<span class="sd"> unsqueeze(input, 0) will produce an output of shape [1, 4, 4],</span>
<span class="sd"> unsqueeze(input, 1) will produce an output of shape [4, 1, 4],</span>
<span class="sd"> unsqueeze(input, -1) will produce an output of shape [4, 4, 1],</span>
<span class="sd"> unsqueeze(input, -2) will produce an output of shape [4, 1, 4].</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to expand with a singleton dimension.</span>
<span class="sd"> axis : int</span>
<span class="sd"> The index of the singleton dimension in the output tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">axis</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">axis</span> <span class="o">=</span> <span class="n">axis</span> <span class="o">+</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">axis</span><span class="p">)</span></div>
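The negative-axis normalization above (`axis += ndim + 1`) matches `np.expand_dims`; a NumPy-only check of the four shapes from the docstring:

```python
import numpy as np

# unsqueeze() inserts one size-1 dimension at 'axis'; negative axes are
# counted from the end of the *output* shape, as in np.expand_dims.
x = np.zeros((4, 4))
for axis in (0, 1, -1, -2):
    print(axis, np.expand_dims(x, axis).shape)
# 0  -> (1, 4, 4)
# 1  -> (4, 1, 4)
# -1 -> (4, 4, 1)
# -2 -> (4, 1, 4)
```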
<div class="viewcode-block" id="expand_dims_like"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_dims_like">[docs]</a><span class="k">def</span> <span class="nf">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">right</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to expand the first tensor to the same rank as the second</span>
<span class="sd"> tensor.</span>
<span class="sd"> That function takes a first tensor. It also accepts an integer or a float,</span>
<span class="sd"> in which case it creates a constant tensor from it. In both cases, the rank</span>
<span class="sd"> of that first tensor is compared to the rank of the second tensor. If they</span>
<span class="sd"> are of the same rank, the first tensor is returned. Otherwise, the first</span>
<span class="sd"> tensor is expanded on the left to match the rank of the second tensor.</span>
<span class="sd"> Note that the shapes do not have to match; only the rank is considered by</span>
<span class="sd"> that function.</span>
<span class="sd"> For example, for a pair of tensors of shapes [3, 4] and [4, 3, 2], the</span>
<span class="sd"> first tensor will be expanded to a tensor of rank 3 and shape [1, 3, 4].</span>
<span class="sd"> Parameters:</span>
<span class="sd"> left : Union[Tensor, int, float]</span>
<span class="sd"> The first tensor to expand. When a scalar value is provided as a</span>
<span class="sd"> parameter, that function first creates a tensor before expanding it</span>
<span class="sd"> (if needed).</span>
<span class="sd"> right : Tensor</span>
<span class="sd"> The reference tensor to match.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the shuffle layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="nb">float</span><span class="p">):</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">and</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span><span class="p">:</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp16_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
<span class="n">left_ndim</span> <span class="o">=</span> <span class="n">left</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">right_ndim</span> <span class="o">=</span> <span class="n">right</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">right_ndim</span> <span class="o">&gt;</span> <span class="n">left_ndim</span><span class="p">:</span>
<span class="n">new_ndim</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">right_ndim</span> <span class="o">-</span> <span class="n">left_ndim</span><span class="p">))</span>
<span class="k">return</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">new_ndim</span><span class="p">)</span>
<span class="k">return</span> <span class="n">left</span></div>
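The rank-matching rule above can be sketched in NumPy (a reference for the semantics only, not the TensorRT graph op; `expand_dims_like_ref` is a hypothetical helper name):

```python
import numpy as np

def expand_dims_like_ref(left, right):
    # Prepend unit dimensions to `left` until its rank matches `right`'s.
    # Only ranks are compared; the shapes themselves need not match.
    missing = right.ndim - left.ndim
    if missing > 0:
        return left.reshape((1,) * missing + left.shape)
    return left

a = np.zeros((3, 4))
b = np.zeros((4, 3, 2))
expanded = expand_dims_like_ref(a, b)   # shape (1, 3, 4)
```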
<span class="c1"># If dim is None, return a 1-D TensorRT-LLM tensor containing the shape of the input.</span>
<span class="c1"># If dim is not None, return a 0-D TensorRT-LLM tensor with the size of the dim-th dimension.</span>
<div class="viewcode-block" id="shape"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.shape">[docs]</a><span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to create a shape tensor.</span>
<span class="sd"> The shape tensor can either be the shape of the input tensor when the</span>
<span class="sd"> parameter dim is None or a scalar (tensor of rank 0) that corresponds to</span>
<span class="sd"> the size of dim-th dimension.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor from which we want to extract the shape or the</span>
<span class="sd"> size in one dimension.</span>
<span class="sd"> dim : Optional[int]</span>
<span class="sd"> The dimension from which to extract the size. If it is None, the</span>
<span class="sd"> entire shape of the input tensor is returned.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor that contains the shape of the input tensor (if &#39;dim&#39; is None)</span>
<span class="sd"> or the size in the dimension &#39;dim&#39; of the input tensor. If &#39;dim&#39; is</span>
<span class="sd"> &#39;None&#39;, that tensor has the same rank as the input tensor, otherwise</span>
<span class="sd"> its rank is 0.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shape</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">res</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span> <span class="n">res</span>
<span class="k">return</span> <span class="n">gather</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">indices</span><span class="o">=</span><span class="n">dim</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">([])</span></div>
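A NumPy sketch of the two return shapes documented above (rank-1 full shape vs. rank-0 per-dimension size); this mirrors the semantics only, not the IShapeLayer graph node:

```python
import numpy as np

x = np.zeros((4, 2, 6))
full_shape = np.array(x.shape)   # rank-1 tensor: [4, 2, 6]
dim_size = np.array(x.shape[1])  # rank-0 tensor (scalar): 2
```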
<div class="viewcode-block" id="gather"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gather">[docs]</a><span class="k">def</span> <span class="nf">gather</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">indices</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to gather elements from a tensor.</span>
<span class="sd"> That function implements the GatherElements operator from the ONNX</span>
<span class="sd"> specification as described in</span>
<span class="sd"> https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements</span>
<span class="sd"> The input and indices arguments must have the same rank &gt;= 1. The operation</span>
<span class="sd"> will produce a tensor with the same shape as the indices tensor. The axis</span>
<span class="sd"> is the dimension to gather on.</span>
<span class="sd"> As shown in the ONNX description, for a 3D tensor, the output is:</span>
<span class="sd"> out[i][j][k] = input[indices[i][j][k]][j][k] if axis = 0,</span>
<span class="sd"> out[i][j][k] = input[i][indices[i][j][k]][k] if axis = 1,</span>
<span class="sd"> out[i][j][k] = input[i][j][indices[i][j][k]] if axis = 2.</span>
<span class="sd"> For example,</span>
<span class="sd"> gather([[4, 2], [5, 3]], 0, [[1, 0], [0, 1]])</span>
<span class="sd"> will produce [[5, 2], [4, 3]].</span>
<span class="sd"> gather([[1, 2, 3], [4, 5, 6]], 1, [[1], [0]])</span>
<span class="sd"> will produce [[2], [4]]. See the ONNX documentation for more examples.</span>
<span class="sd"> That operation maps to the TensorRT IGatherLayer.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to gather elements from.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension to gather on.</span>
<span class="sd"> indices : Union[Tensor, int]</span>
<span class="sd"> The positions in the &#39;dim&#39; dimension to gather from.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor containing the gathered elements. It has the same shape as</span>
<span class="sd"> the indices tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">indices</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">indices</span><span class="p">]))</span>
<span class="c1"># The input and indices tensors must have the same rank.</span>
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="n">indices</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather_v2</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">indices</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">GatherMode</span><span class="o">.</span><span class="n">ELEMENT</span><span class="p">)</span>
<span class="k">if</span> <span class="n">dim</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="n">dim</span>
<span class="n">layer</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">dim</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
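The GatherElements semantics in the docstring match NumPy's `take_along_axis`, which can serve as a host-side reference (this is not the TensorRT layer itself):

```python
import numpy as np

data = np.array([[4, 2], [5, 3]])
idx = np.array([[1, 0], [0, 1]])
# GatherElements along axis 0: out[i][j] = data[idx[i][j]][j]
out = np.take_along_axis(data, idx, axis=0)   # [[5, 2], [4, 3]]
```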
<div class="viewcode-block" id="select"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.select">[docs]</a><span class="k">def</span> <span class="nf">select</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to select a slice of elements from a tensor.</span>
<span class="sd"> Given an input tensor, that function creates an operation that selects the</span>
<span class="sd"> index-th slice of elements in the dimension &#39;dim&#39; to create a new tensor.</span>
<span class="sd"> The output tensor has a shape in which the input dimension &#39;dim&#39; is</span>
<span class="sd"> removed.</span>
<span class="sd"> The &#39;index&#39; can either be an integer or a 1D tensor containing a single</span>
<span class="sd"> element.</span>
<span class="sd"> For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
<span class="sd"> [3, 3],</span>
<span class="sd"> select(input, 0, 1)</span>
<span class="sd"> will create a tensor of shape [3] that contains [2, 1, 2].</span>
<span class="sd"> Regarding the shape of the output tensor, the dimension &#39;dim&#39; is removed.</span>
<span class="sd"> It means that for a tensor of shape [4, 2, 6, 3],</span>
<span class="sd"> select(input, 2, 4)</span>
<span class="sd"> will select the 5th slice (index == 4) from the 3rd dimension (dim == 2)</span>
<span class="sd"> and return a tensor of shape [4, 2, 3] (i.e. the 3rd dimension is removed).</span>
<span class="sd"> That operation maps to the TensorRT IGatherLayer.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to select from.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension to select from.</span>
<span class="sd"> index : Union[Tensor, int]</span>
<span class="sd"> The index of the slice in the &#39;dim&#39; dimension to select.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor containing the selected slice.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">index</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">index</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">index</span><span class="p">]))</span>
<span class="k">assert</span> <span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">index</span><span class="o">.</span><span class="n">size</span><span class="p">(</span>
<span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;index should have rank 1, got </span><span class="si">{</span><span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()):</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">:</span>
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">index</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">new_shape</span><span class="p">))</span></div>
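The slice-and-remove-dimension behavior described above corresponds to NumPy's `take` with a scalar index (a semantics sketch only, not the IGatherLayer call):

```python
import numpy as np

x = np.array([[4, 2, 5], [2, 1, 2], [4, 7, 1]])   # shape (3, 3)
row = np.take(x, 1, axis=0)   # slice at index 1 along dim 0 -> [2, 1, 2]
# The selected dimension is removed: the result has shape (3,), not (1, 3).
```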
<div class="viewcode-block" id="index_select"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.index_select">[docs]</a><span class="k">def</span> <span class="nf">index_select</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to select slices of elements from a tensor.</span>
<span class="sd"> Given an input tensor, that function creates an operation that selects the</span>
<span class="sd"> slices of elements in the dimension &#39;dim&#39; at the indices listed in &#39;index&#39;</span>
<span class="sd"> to create a new tensor. The output tensor has the same rank as the input</span>
<span class="sd"> tensor.</span>
<span class="sd"> The &#39;index&#39; is a tensor of rank 1.</span>
<span class="sd"> For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
<span class="sd"> [3, 3],</span>
<span class="sd"> index_select(input, 0, [0, 1])</span>
<span class="sd"> will create a tensor of shape [2, 3] that contains [[4, 2, 5], [2, 1, 2]].</span>
<span class="sd"> Regarding the shape of the output tensor, the dimension &#39;dim&#39; has the same</span>
<span class="sd"> size as the &#39;index&#39; tensor. It means that for a tensor of shape [4, 2, 6, 3],</span>
<span class="sd"> index_select(input, 2, [1, 4])</span>
<span class="sd"> will select the 2nd and 5th slices (index == 1 or 4) from the 3rd dimension</span>
<span class="sd"> (dim == 2) and return a tensor of shape [4, 2, 2, 3] (i.e. the 3rd</span>
<span class="sd"> dimension is shrunk to 2).</span>
<span class="sd"> Note that this operation can also be used to expand a tensor in the &#39;dim&#39;</span>
<span class="sd"> dimension, for example, on input [[0, 1], [2, 3]],</span>
<span class="sd"> index_select(input, 1, [0, 0, 0])</span>
<span class="sd"> will produce a tensor of shape [2, 3] containing [[0, 0, 0], [2, 2, 2]].</span>
<span class="sd"> That operation maps to the TensorRT IGatherLayer.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to select from.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension to select from.</span>
<span class="sd"> index : Tensor</span>
<span class="sd"> The indices of the slices in the &#39;dim&#39; dimension to select.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor containing the selected slices.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;index should have rank 1, got </span><span class="si">{</span><span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()):</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">:</span>
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="n">index</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">index</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">new_shape</span><span class="p">))</span></div>
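The expansion example in the docstring (repeating an index to grow a dimension) can be checked against NumPy's `take` with an index array (a reference for the semantics, not the TensorRT op):

```python
import numpy as np

x = np.array([[0, 1], [2, 3]])
out = np.take(x, [0, 0, 0], axis=1)   # [[0, 0, 0], [2, 2, 2]], shape (2, 3)
```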
<div class="viewcode-block" id="concat"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.concat">[docs]</a><span class="k">def</span> <span class="nf">concat</span><span class="p">(</span><span class="n">inputs</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">]],</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to concatenate tensors.</span>
<span class="sd"> The function creates an operation that concatenates the tensors from the</span>
<span class="sd"> sequence &#39;inputs&#39;. The concatenation is done along the dimension &#39;dim&#39;.</span>
<span class="sd"> All the tensors in &#39;inputs&#39; must have the same shape except for the</span>
<span class="sd"> dimension &#39;dim&#39;.</span>
<span class="sd"> for ii in range(inputs[0].rank()):</span>
<span class="sd"> assert (ii == dim) or all(inp.shape[ii] == inputs[0].shape[ii] for inp in inputs)</span>
<span class="sd"> The shape of the output tensor is defined as:</span>
<span class="sd"> for ii in range(inputs[0].rank()):</span>
<span class="sd"> # Same size as all the inputs in dimension ii != dim.</span>
<span class="sd"> output.shape[ii] = inputs[0].shape[ii]</span>
<span class="sd"> # Sum of the sizes in the different inputs in dimension &#39;dim&#39;.</span>
<span class="sd"> if ii == dim:</span>
<span class="sd"> for jj in range(1, len(inputs)):</span>
<span class="sd"> output.shape[ii] += inputs[jj].shape[ii]</span>
<span class="sd"> For example, given a sequence of two 2D tensors [[0, 1], [2, 3]] and</span>
<span class="sd"> [[4, 5], [6, 7]] both of shape [2, 2],</span>
<span class="sd"> concat(inputs, 0)</span>
<span class="sd"> will produce [[0, 1], [2, 3], [4, 5], [6, 7]] of shape [4, 2] and</span>
<span class="sd"> concat(inputs, 1)</span>
<span class="sd"> will produce [[0, 1, 4, 5], [2, 3, 6, 7]] of shape [2, 4].</span>
<span class="sd"> Parameters:</span>
<span class="sd"> inputs : Sequence[Union[Tensor, int]]</span>
<span class="sd"> The sequence of tensors to concatenate. For integers, that function</span>
<span class="sd"> creates constant tensors.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension in which the concatenation is performed.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor that contains the concatenation of the tensors.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">tmp</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">tmp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">i</span><span class="p">])))</span>
<span class="k">elif</span> <span class="n">i</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">tmp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">i</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="mi">1</span><span class="p">]))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">tmp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_concatenation</span><span class="p">([</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">tmp</span><span class="p">])</span>
<span class="n">layer</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">dim</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
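The two concatenation examples from the docstring, reproduced with NumPy as a host-side reference (the actual op builds an IConcatenationLayer):

```python
import numpy as np

a = np.array([[0, 1], [2, 3]])
b = np.array([[4, 5], [6, 7]])
along_rows = np.concatenate([a, b], axis=0)  # shape (4, 2)
along_cols = np.concatenate([a, b], axis=1)  # shape (2, 4)
```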
<div class="viewcode-block" id="softmax"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.softmax">[docs]</a><span class="k">def</span> <span class="nf">softmax</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to compute softmax on a tensor.</span>
<span class="sd"> That operation computes the softmax on the input tensor in the dimension</span>
<span class="sd"> &#39;dim&#39; if specified. Otherwise, it is applied on the last dimension.</span>
<span class="sd"> It inserts a ISoftmaxLayer to the TensorRT graph.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which to apply softmax.</span>
<span class="sd"> dim : Optional[int]</span>
<span class="sd"> The dimension used to apply softmax.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The output tensor of the softmax layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">dim</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="n">dim</span>
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_softmax</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
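A minimal NumPy sketch of the softmax applied along a chosen dimension (defaulting to the last, as documented); `softmax_ref` is a hypothetical reference name, not the ISoftmaxLayer itself:

```python
import numpy as np

def softmax_ref(x, dim=-1):
    # Subtract the per-slice max for numerical stability, then normalize.
    e = np.exp(x - x.max(axis=dim, keepdims=True))
    return e / e.sum(axis=dim, keepdims=True)

probs = softmax_ref(np.array([[1.0, 2.0, 3.0]]))  # each row sums to 1
```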
<span class="k">def</span> <span class="nf">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to perform lookup in a tensor.</span>
<span class="sd"> That operation performs the lookup needed by embedding layers. Given a</span>
<span class="sd"> &#39;weight&#39; tensor of shape [rows, cols], it produces a tensor of shape</span>
<span class="sd"> [inputs.size(0), cols] where the ith row corresponds to the input[i] row in</span>
<span class="sd"> the weight tensor.</span>
<span class="sd"> It inserts a IPluginV2Layer.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor that contains the indices to perform the lookup.</span>
<span class="sd"> weight : Tensor</span>
<span class="sd"> The table to gather from.</span>
<span class="sd"> rank : int</span>
<span class="sd"> The mpi rank.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The output tensor of the lookup layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;Lookup&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lookup_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;rank&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">rank</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">pf_type</span><span class="p">,</span> <span class="n">rank</span><span class="p">])</span>
<span class="n">lookup_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;lookup&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">lookup_plug</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<div class="viewcode-block" id="embedding"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.embedding">[docs]</a><span class="k">def</span> <span class="nf">embedding</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">tp_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">tp_group</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">sharding_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">tp_rank</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to perform embedding lookup.</span>
<span class="sd"> That operation performs the embedding lookup. The &#39;input&#39; tensor contains</span>
<span class="sd"> the identifiers of the rows of &#39;weight&#39; to gather.</span>
<span class="sd"> 1. Distribute the embedding lookup table over multiple GPUs</span>
<span class="sd"> When &#39;tp_size&#39; is greater than 1 and the &#39;tp_group&#39; is defined, this</span>
<span class="sd"> embedding lookup is distributed among multiple GPUs.</span>
<span class="sd"> When &#39;sharding_dim==0&#39;, each GPU stores a subset of the rows of the embedding</span>
<span class="sd"> table (the number of rows per GPU is given by weight.shape[0] and the offset to</span>
<span class="sd"> the first row stored on the GPU is given by rank * weight.shape[0]). Each</span>
<span class="sd"> parallel rank will query all the indices and set 0s for the weights that</span>
<span class="sd"> are not stored on the associated GPU. To compute the final result, a</span>
<span class="sd"> parallel all-reduce operation is added to the TensorRT graph. That lookup</span>
<span class="sd"> can be performed using either the plugin or the operators TensorRT support.</span>
<span class="sd"> When&#39;sharding_dim==1&#39;, each GPU stores a subset of the embedding table&#39;s columns.</span>
<span class="sd"> Each rank can obtain a portion of the embedding results.</span>
<span class="sd"> Then the embedding is collected using the all-gather operation.</span>
<span class="sd"> Related transposition operations are also used to obtain the final results.</span>
<span class="sd"> 2. Store embedding lookup table as a whole</span>
<span class="sd"> When &#39;tp_size&#39; is not greater than 1, the embedding lookup table will not</span>
<span class="sd"> be divided. In this case, when the default_net().plugin_config.lookup_plugin is set,</span>
<span class="sd"> the operation is implemented using a plugin (without the all-reduce operation).</span>
<span class="sd"> Otherwise, this operation is implemented using the standard IGatherLayer in TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor that contains the indices to perform the lookup.</span>
<span class="sd"> weight : Tensor</span>
<span class="sd"> The table to gather from.</span>
<span class="sd"> tp_size : int</span>
<span class="sd"> The number of GPUs collaborating to perform that embedding.</span>
<span class="sd"> tp_group : Optional[List[int]]</span>
<span class="sd"> The group of world ranks participating in the all-reduce when</span>
<span class="sd"> tp_size &gt; 1.</span>
<span class="sd"> sharding_dim : int</span>
<span class="sd"> sharding_dim = 0 means that we shard the embedding table in vocab dim;</span>
<span class="sd"> sharding_dim = 1 means that we shard the embedding table in embedding dim.</span>
<span class="sd"> tp_rank : int</span>
<span class="sd"> The tensor parallelism rank. Used to calculate offset in TP on vocab dim.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the embedding lookup layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># Distribute embedding lookup table across multiple GPUs</span>
<span class="k">if</span> <span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">tp_group</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="n">sharding_dim</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># TP on vocab_size dimension</span>
<span class="k">if</span> <span class="n">tp_rank</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Rank cannot be none for tensor parallelism on vocab dim&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lookup_plugin</span><span class="p">:</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">tp_rank</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">allreduce</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">shape_weight</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">weight</span><span class="p">)</span>
<span class="n">vocab_size</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">shape_weight</span><span class="p">,</span> <span class="n">starts</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">sizes</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="n">tmp_input</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">-</span> <span class="n">vocab_size</span> <span class="o">*</span> <span class="n">tp_rank</span>
<span class="c1"># Identify the valid indices</span>
<span class="n">is_qualified</span> <span class="o">=</span> <span class="n">op_and</span><span class="p">(</span><span class="n">tmp_input</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">tmp_input</span> <span class="o">&lt;</span> <span class="n">vocab_size</span><span class="p">)</span>
<span class="n">is_qualified_expand</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">is_qualified</span><span class="p">,</span>
<span class="p">[</span><span class="n">is_qualified</span><span class="o">.</span><span class="n">ndim</span><span class="p">()])</span>
<span class="c1"># Replace the invalid ones to zero</span>
<span class="n">placeholder_input</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">is_qualified</span><span class="p">,</span> <span class="n">tmp_input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="c1"># Get the temporal results</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span>
<span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">placeholder_input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">tmp_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="c1"># Set zero for invalid results</span>
<span class="n">placeholder_tmp</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">is_qualified_expand</span><span class="p">,</span> <span class="n">tmp_output</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">placeholder</span> <span class="o">=</span> <span class="n">placeholder_tmp</span> <span class="o">-</span> <span class="n">placeholder_tmp</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">is_qualified_expand</span><span class="p">,</span> <span class="n">tmp_output</span><span class="p">,</span> <span class="n">placeholder</span><span class="p">)</span>
<span class="c1"># Use all reduce to collect the results</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">allreduce</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">sharding_dim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span> <span class="c1"># TP on hidden dimension</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="c1"># 1. [dim0, local_dim] -&gt; [dim0 * tp_size, local_dim]</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">allgather</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">)</span>
<span class="c1"># 2. [dim0 * tp_size, local_dim] -&gt; [dim0, local_dim * tp_size]</span>
<span class="c1"># 2.1 split</span>
<span class="n">split_size</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">tp_size</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">starts</span> <span class="o">=</span> <span class="p">[</span><span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="mi">0</span><span class="p">]))</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
<span class="n">sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">d</span><span class="p">)</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
<span class="n">sizes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">split_size</span>
<span class="n">sections</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tp_size</span><span class="p">):</span>
<span class="n">starts</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">split_size</span> <span class="o">*</span> <span class="n">i</span>
<span class="n">sections</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
<span class="c1"># 2.2 concat</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">sections</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s1">&#39;Tensor Parallelism only support splitting Embedding lookup along hidden (sharding_dim==1) and vocab (sharding_dim==0) dimensionis&#39;</span>
<span class="p">)</span>
<span class="c1"># Store embedding lookup table as a whole</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lookup_plugin</span><span class="p">:</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">return</span> <span class="n">x</span></div>
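The vocab-dimension (sharding_dim==0) path above can be sketched with plain NumPy: each rank remaps the global indices into its local row range, gathers with out-of-range indices clamped to 0, zeroes out the invalid rows, and the summation over ranks stands in for the all-reduce. This is a minimal illustration under those assumptions, not the TensorRT graph the function actually builds:

```python
import numpy as np

def sharded_embedding(indices: np.ndarray, full_weight: np.ndarray,
                      tp_size: int) -> np.ndarray:
    rows_per_rank = full_weight.shape[0] // tp_size
    partials = []
    for rank in range(tp_size):
        # This rank's contiguous slice of the embedding table.
        shard = full_weight[rank * rows_per_rank:(rank + 1) * rows_per_rank]
        # Offset global indices into this rank's local range.
        local = indices - rank * rows_per_rank
        valid = (local >= 0) & (local < rows_per_rank)
        # Gather with invalid indices replaced by 0, then zero those rows.
        gathered = shard[np.where(valid, local, 0)]
        gathered[~valid] = 0.0
        partials.append(gathered)
    # Summing the per-rank partial results plays the role of the all-reduce.
    return np.sum(partials, axis=0)

weight = np.arange(16, dtype=np.float32).reshape(8, 2)
idx = np.array([0, 5, 7])
out = sharded_embedding(idx, weight, tp_size=2)
```

Because exactly one rank holds each row, every output row is contributed by a single rank and the sum reproduces the unsharded lookup `weight[idx]`.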
<div class="viewcode-block" id="constant_to_tensor_"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.constant_to_tensor_">[docs]</a><span class="k">def</span> <span class="nf">constant_to_tensor_</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
<span class="n">dtype</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="k">return</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="nb">input</span><span class="p">]))</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="nb">float</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">float32</span> <span class="ow">or</span> <span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">float16</span>
<span class="k">if</span> <span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span>
<span class="k">return</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">([</span><span class="nb">input</span><span class="p">]))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp16_array</span><span class="p">([</span><span class="nb">input</span><span class="p">]))</span>
<span class="k">return</span> <span class="nb">input</span></div>
<div class="viewcode-block" id="broadcast_helper"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.broadcast_helper">[docs]</a><span class="k">def</span> <span class="nf">broadcast_helper</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
<span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Helper function to perform a broadcast.</span>
<span class="sd"> For each input, that function first creates a constant tensor if the input</span>
<span class="sd"> is an integer or a float. Then, if needed, it expands the smaller tensor to</span>
<span class="sd"> make sure its rank is the same as the larger one.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> left : Union[Tensor, int, float]</span>
<span class="sd"> The first input. If that input is an integer or a float, the</span>
<span class="sd"> function creates a constant tensor.</span>
<span class="sd"> right : Union[Tensor, int, float]</span>
<span class="sd"> The second input. If that input is an integer or a float, the</span>
<span class="sd"> function creates a constant tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A pair of tensors of same rank.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">left</span><span class="p">)</span>
<span class="n">right</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">right</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span>
<span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">right</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">left</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">&lt;</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">&gt;</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
<span class="n">right</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">left</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span></div>
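The rank-alignment that broadcast_helper performs can be mimicked in NumPy: promote scalars to arrays, then prepend size-1 dimensions to the lower-rank operand until both ranks match. A hypothetical sketch of that behavior (NumPy stands in for the TensorRT constant/expand_dims_like machinery):

```python
import numpy as np

def broadcast_ranks(left, right):
    # Promote Python scalars to 0-d arrays (the constant_to_tensor_ step).
    left, right = np.asarray(left, dtype=np.float32), np.asarray(right, dtype=np.float32)
    # Left-pad the smaller rank with size-1 dims (the expand_dims_like step).
    while left.ndim < right.ndim:
        left = left[np.newaxis, ...]
    while right.ndim < left.ndim:
        right = right[np.newaxis, ...]
    return left, right

a, b = broadcast_ranks(2.0, np.ones((2, 3), dtype=np.float32))
```

With equal ranks, the subsequent elementwise layer can rely on ordinary size-1 broadcasting, e.g. `a * b` yields a (2, 3) result filled with 2.0.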
<div class="viewcode-block" id="elementwise_binary"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.elementwise_binary">[docs]</a><span class="k">def</span> <span class="nf">elementwise_binary</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span>
<span class="nb">float</span><span class="p">],</span> <span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
<span class="n">op</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an elementwise operation with two inputs.</span>
<span class="sd"> For each input, that function first creates a constant tensor if the input</span>
<span class="sd"> is an integer or a float. Then, if needed, it expands the smaller tensor to</span>
<span class="sd"> make sure its rank is the same as the larger one. Then, it performs the</span>
<span class="sd"> elementwise operation &#39;op&#39;.</span>
<span class="sd"> The following closures are defined in functional.*:</span>
<span class="sd"> add for op=trt.ElementWiseOperation.SUM</span>
<span class="sd"> sub for op=trt.ElementWiseOperation.SUB</span>
<span class="sd"> mul for op=trt.ElementWiseOperation.PROD</span>
<span class="sd"> div for op=trt.ElementWiseOperation.DIV</span>
<span class="sd"> gt for op=trt.ElementWiseOperation.GREATER</span>
<span class="sd"> lt for op=trt.ElementWiseOperation.LESS</span>
<span class="sd"> op_and for op=trt.ElementWiseOperation.AND</span>
<span class="sd"> op_or for op=trt.ElementWiseOperation.OR</span>
<span class="sd"> eq for op=trt.ElementWiseOperation.EQUAL</span>
<span class="sd"> minimum for op=trt.ElementWiseOperation.MIN</span>
<span class="sd"> maximum for op=trt.ElementWiseOperation.MAX</span>
<span class="sd"> pow for op=trt.ElementWiseOperation.POW</span>
<span class="sd"> It is implemented using the IElementWiseLayer from TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> left : Union[Tensor, int, float]</span>
<span class="sd"> The first input. If that input is an integer or a float, the</span>
<span class="sd"> function creates a constant tensor.</span>
<span class="sd"> right : Union[Tensor, int, float]</span>
<span class="sd"> The second input. If that input is an integer or a float, the</span>
<span class="sd"> function creates a constant tensor.</span>
<span class="sd"> op : trt.ElementWiseOperation</span>
<span class="sd"> The binary operation to perform.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this elementwise operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">left</span><span class="p">,</span> <span class="n">right</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_elementwise</span><span class="p">(</span><span class="n">left</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">right</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">op</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<span class="n">add</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">SUM</span><span class="p">)</span>
<span class="n">sub</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">SUB</span><span class="p">)</span>
<span class="n">mul</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">PROD</span><span class="p">)</span>
<span class="n">div</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">DIV</span><span class="p">)</span>
<span class="n">gt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">GREATER</span><span class="p">)</span>
<span class="n">lt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">LESS</span><span class="p">)</span>
<span class="n">op_and</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">AND</span><span class="p">)</span>
<span class="n">op_or</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">OR</span><span class="p">)</span>
<span class="n">eq</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">EQUAL</span><span class="p">)</span>
<span class="n">minimum</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">MIN</span><span class="p">)</span>
<span class="n">maximum</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">)</span>
<span class="nb">pow</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">POW</span><span class="p">)</span>
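The closures above simply bind `op` with `functools.partial` (note that `pow`, and later `round`, `abs`, and `max`, intentionally shadow the Python builtins inside this module). A minimal pure-Python sketch of the same pattern, emulating only the elementwise math and not the TensorRT graph construction (`elementwise_binary_sim`, `add_sim`, and `maximum_sim` are illustrative names, not part of the module):

```python
from functools import partial
import operator

def elementwise_binary_sim(left, right, op):
    # Stand-in for the IElementWiseLayer: apply `op` lane by lane.
    return [op(l, r) for l, r in zip(left, right)]

# Bind the operation up front, exactly as the module does with trt ops.
add_sim = partial(elementwise_binary_sim, op=operator.add)
maximum_sim = partial(elementwise_binary_sim, op=max)

print(add_sim([1, 2, 3], [10, 20, 30]))  # [11, 22, 33]
print(maximum_sim([1, 9], [5, 2]))       # [5, 9]
```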
<div class="viewcode-block" id="where"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.where">[docs]</a><span class="k">def</span> <span class="nf">where</span><span class="p">(</span><span class="n">condition</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
<span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a where (aka select or if-then-else) operation.</span>
<span class="sd"> Assuming the three input parameters have the same shape, that function creates</span>
<span class="sd"> the operation to compute a tensor of the same shape such that:</span>
<span class="sd"> for ii in range(mul(condition.shape)):</span>
<span class="sd"> output[ii] = left[ii] if condition[ii] else right[ii]</span>
<span class="sd"> For each input, that function first creates a constant tensor if the input</span>
<span class="sd"> is an integer or a float. Then, if needed, it expands the smaller tensor to</span>
<span class="sd"> make sure its rank is the same as the larger one. Then, it performs the</span>
<span class="sd"> selection.</span>
<span class="sd"> It is implemented using the ISelectLayer from TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> condition : Union[Tensor, int, float]</span>
<span class="sd"> The condition. If that input is an integer or a float, the function</span>
<span class="sd"> creates a constant tensor.</span>
<span class="sd"> left : Union[Tensor, int, float]</span>
<span class="sd"> The first input. If that input is an integer or a float, the</span>
<span class="sd"> function creates a constant tensor.</span>
<span class="sd"> right : Union[Tensor, int, float]</span>
<span class="sd"> The second input. If that input is an integer or a float, the</span>
<span class="sd"> function creates a constant tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this select operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># Convert to tensors.</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">condition</span><span class="p">)</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">left</span><span class="p">)</span>
<span class="n">right</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">right</span><span class="p">)</span>
<span class="c1"># Find the tensor with the largest rank of the three.</span>
<span class="n">largest</span> <span class="o">=</span> <span class="n">condition</span>
<span class="k">if</span> <span class="n">largest</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">&lt;</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
<span class="n">largest</span> <span class="o">=</span> <span class="n">left</span>
<span class="k">if</span> <span class="n">largest</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">&lt;</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
<span class="n">largest</span> <span class="o">=</span> <span class="n">right</span>
<span class="c1"># Expand the tensors to match the largest one.</span>
<span class="k">if</span> <span class="n">condition</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>
<span class="k">if</span> <span class="n">left</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>
<span class="k">if</span> <span class="n">right</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
<span class="n">right</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>
<span class="c1"># Insert the operation.</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_select</span><span class="p">(</span><span class="n">condition</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">left</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">right</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
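The per-element selection described in the docstring above can be checked with a pure-Python stand-in for the ISelectLayer (`where_sim` is an illustrative name, not part of the module):

```python
def where_sim(condition, left, right):
    # Mirrors the docstring: output[ii] = left[ii] if condition[ii] else right[ii]
    return [l if c else r for c, l, r in zip(condition, left, right)]

print(where_sim([True, False, True], [1, 2, 3], [-1, -2, -3]))  # [1, -2, 3]
```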
<div class="viewcode-block" id="unary"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.unary">[docs]</a><span class="k">def</span> <span class="nf">unary</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">op</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an elementwise operation on a single input.</span>
<span class="sd"> The following closures are defined in functional.*:</span>
<span class="sd"> round for op=trt.UnaryOperation.ROUND</span>
<span class="sd"> sqrt for op=trt.UnaryOperation.SQRT</span>
<span class="sd"> exp for op=trt.UnaryOperation.EXP</span>
<span class="sd"> sin for op=trt.UnaryOperation.SIN</span>
<span class="sd"> cos for op=trt.UnaryOperation.COS</span>
<span class="sd"> abs for op=trt.UnaryOperation.ABS</span>
<span class="sd"> It is implemented using the IUnaryLayer from TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> op : trt.UnaryOperation</span>
<span class="sd"> The unary operation to perform.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this elementwise operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_unary</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<span class="nb">round</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">ROUND</span><span class="p">)</span>
<span class="n">sqrt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">SQRT</span><span class="p">)</span>
<span class="n">exp</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">EXP</span><span class="p">)</span>
<span class="n">sin</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">SIN</span><span class="p">)</span>
<span class="n">cos</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">COS</span><span class="p">)</span>
<span class="nb">abs</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">ABS</span><span class="p">)</span>
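The unary closures follow the same `partial` pattern as the binary ones. A pure-Python sketch of the semantics (illustrative names, standard `math` functions standing in for the TensorRT unary ops):

```python
from functools import partial
import math

def unary_sim(values, op):
    # Stand-in for the IUnaryLayer: apply `op` to each element.
    return [op(v) for v in values]

sqrt_sim = partial(unary_sim, op=math.sqrt)
abs_sim = partial(unary_sim, op=abs)

print(sqrt_sim([1.0, 4.0, 9.0]))  # [1.0, 2.0, 3.0]
print(abs_sim([-2, 3]))           # [2, 3]
```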
<div class="viewcode-block" id="mean"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.mean">[docs]</a><span class="k">def</span> <span class="nf">mean</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to compute the mean along a dimension.</span>
<span class="sd"> Computes the mean along the dimension &#39;dim&#39; of the input tensor.</span>
<span class="sd"> It is implemented using the IReduceLayer from TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension along which the mean is computed.</span>
<span class="sd"> keepdim : bool</span>
<span class="sd"> If True, the reduced dimension is kept (with size 1) in the</span>
<span class="sd"> output; otherwise, it is removed from the shape.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this reduction operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_reduce</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">AVG</span><span class="p">,</span>
<span class="n">axes</span><span class="p">,</span>
<span class="n">keep_dims</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
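The reduction semantics (average along one axis, with the axis optionally kept) can be sketched in pure Python on a 2D nested list (`mean_sim` is an illustrative stand-in for the IReduceLayer with `trt.ReduceOperation.AVG`):

```python
def mean_sim(matrix, dim, keepdim=False):
    # Mean-reduce a 2D nested list along dim 0 (columns) or dim 1 (rows).
    if dim == 1:
        reduced = [sum(row) / len(row) for row in matrix]
        return [[m] for m in reduced] if keepdim else reduced
    cols = list(zip(*matrix))
    reduced = [sum(c) / len(c) for c in cols]
    return [reduced] if keepdim else reduced

m = [[1.0, 3.0], [5.0, 7.0]]
print(mean_sim(m, dim=1))                # [2.0, 6.0]
print(mean_sim(m, dim=0, keepdim=True))  # [[3.0, 5.0]]
```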
<div class="viewcode-block" id="max"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.max">[docs]</a><span class="k">def</span> <span class="nf">max</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to compute the max along a dimension.</span>
<span class="sd"> Computes the max along the dimension &#39;dim&#39; of the input tensor.</span>
<span class="sd"> It is implemented using the IReduceLayer from TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension along which the max is computed.</span>
<span class="sd"> keepdim : bool</span>
<span class="sd"> If True, the reduced dimension is kept (with size 1) in the</span>
<span class="sd"> output; otherwise, it is removed from the shape.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this reduction operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_reduce</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">,</span>
<span class="n">axes</span><span class="p">,</span>
<span class="n">keep_dims</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
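As with mean above, `keepdim` decides whether the reduced axis survives with size 1 or is dropped. A pure-Python sketch of that shape behavior for the max reduction (`max_sim` is illustrative only):

```python
def max_sim(matrix, dim, keepdim=False):
    # Max-reduce a 2D nested list along dim 0 (columns) or dim 1 (rows).
    if dim == 1:
        reduced = [max(row) for row in matrix]
        return [[v] for v in reduced] if keepdim else reduced
    reduced = [max(col) for col in zip(*matrix)]
    return [reduced] if keepdim else reduced

m = [[1, 8], [5, 2]]
print(max_sim(m, dim=1))                # [8, 5]    (axis dropped)
print(max_sim(m, dim=1, keepdim=True))  # [[8], [5]] (axis kept, size 1)
```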
<div class="viewcode-block" id="identity"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.identity">[docs]</a><span class="k">def</span> <span class="nf">identity</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an identity operation.</span>
<span class="sd"> Note: when the identity plugin is enabled in the plugin config, this</span>
<span class="sd"> operation is implemented by a custom plugin rather than TensorRT&#39;s</span>
<span class="sd"> IIdentityLayer.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this identity operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">identity_plugin</span><span class="p">:</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_identity</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;Identity&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">()</span>
<span class="n">id_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;identity&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">id_plug</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="argmax"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.argmax">[docs]</a><span class="k">def</span> <span class="nf">argmax</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an argmax operation.</span>
<span class="sd"> As explained in the ONNX documentation,</span>
<span class="sd"> https://github.com/onnx/onnx/blob/main/docs/Operators.md#argmax</span>
<span class="sd"> that function creates a layer computing the indices of the max elements of</span>
<span class="sd"> the input tensor&#39;s element along the provided dim. The resulting tensor</span>
<span class="sd"> has the same rank as the input if keepdim is True. If keepdim is False,</span>
<span class="sd"> then the resulting tensor has the reduced dimension pruned.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension in which to compute the argmax indices.</span>
<span class="sd"> keepdim : bool</span>
<span class="sd"> If True, the reduced dimension is kept (with size 1) in the</span>
<span class="sd"> output; otherwise, it is removed from the shape.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this argmax operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_topk</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">TopKOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">,</span>
<span class="mi">1</span><span class="p">,</span> <span class="n">axes</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">keepdim</span><span class="p">:</span>
<span class="k">return</span> <span class="n">output</span>
<span class="n">a</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()))</span>
<span class="n">a</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">a</span><span class="p">]))</span>
<span class="n">output_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="n">output_shape</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">indices</span><span class="p">)</span>
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">new_shape</span><span class="p">)</span></div>
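The argmax semantics (index of the max along one axis, reduced axis kept or pruned per `keepdim`) can be checked with a pure-Python stand-in for the top-k layer with k=1 (`argmax_sim` is an illustrative name, not part of the module):

```python
def argmax_sim(matrix, dim, keepdim=False):
    # Index of the max element along dim of a 2D nested list.
    if dim == 1:
        idx = [row.index(max(row)) for row in matrix]
        return [[i] for i in idx] if keepdim else idx
    idx = [col.index(max(col)) for col in zip(*matrix)]
    return [idx] if keepdim else idx

m = [[3, 9, 1], [7, 2, 8]]
print(argmax_sim(m, dim=1))                # [1, 2]
print(argmax_sim(m, dim=0, keepdim=True))  # [[1, 0, 1]]
```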
<div class="viewcode-block" id="gelu"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gelu">[docs]</a><span class="k">def</span> <span class="nf">gelu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a GELU operation.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the activation function is applied.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the activation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="k">if</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span>
<span class="n">v1</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp16_array</span><span class="p">([</span><span class="mf">0.5</span><span class="p">]))</span>
<span class="n">v2</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp16_array</span><span class="p">([</span><span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">2.0</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span><span class="p">)]))</span>
<span class="n">v3</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp16_array</span><span class="p">([</span><span class="mf">0.044715</span><span class="p">]))</span>
<span class="n">v4</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp16_array</span><span class="p">([</span><span class="mf">3.0</span><span class="p">]))</span>
<span class="n">v5</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp16_array</span><span class="p">([</span><span class="mf">1.0</span><span class="p">]))</span>
<span class="k">elif</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span>
<span class="n">v1</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">([</span><span class="mf">0.5</span><span class="p">]))</span>
<span class="n">v2</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">([</span><span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">2.0</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span><span class="p">)]))</span>
<span class="n">v3</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">([</span><span class="mf">0.044715</span><span class="p">]))</span>
<span class="n">v4</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">([</span><span class="mf">3.0</span><span class="p">]))</span>
<span class="n">v5</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">([</span><span class="mf">1.0</span><span class="p">]))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;gelu on datatype of </span><span class="si">{</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2"> is not supported&quot;</span>
<span class="k">return</span> <span class="n">v1</span> <span class="o">*</span> <span class="n">x</span> <span class="o">*</span> <span class="p">(</span><span class="n">tanh</span><span class="p">(</span><span class="n">v2</span> <span class="o">*</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">v3</span> <span class="o">*</span> <span class="nb">pow</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">v4</span><span class="p">)))</span> <span class="o">+</span> <span class="n">v5</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">x</span> <span class="o">*</span> <span class="p">(</span>
<span class="n">tanh</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">2.0</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="mf">0.044715</span> <span class="o">*</span> <span class="nb">pow</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mf">3.0</span><span class="p">)))</span> <span class="o">+</span> <span class="mf">1.0</span><span class="p">)</span></div>
<div class="viewcode-block" id="geglu"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.geglu">[docs]</a><span class="k">def</span> <span class="nf">geglu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a Gated-GELU operation.</span>
<span class="sd"> That function takes a tensor, splits it into two halves along the last</span>
<span class="sd"> dimension, applies GELU to the second half and multiply the results. The</span>
<span class="sd"> behaviour is undefined if the last dimension is not even.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the activation function is applied.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the activation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">chunk</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">a</span> <span class="o">*</span> <span class="n">gelu</span><span class="p">(</span><span class="n">b</span><span class="p">)</span></div>
<div class="viewcode-block" id="group_norm"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.group_norm">[docs]</a><span class="k">def</span> <span class="nf">group_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">num_groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">):</span>
<span class="c1">##</span>
<span class="c1">## TODO: Document that function!</span>
<span class="c1">##</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="n">num_channels</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">old_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
<span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">num_groups</span><span class="p">,</span>
<span class="n">num_channels</span> <span class="o">//</span> <span class="n">num_groups</span><span class="p">,</span>
<span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">ndim</span><span class="p">)])</span>
<span class="n">x</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
<span class="n">reduce_dim</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">ndim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">ux</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">reduce_dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">numerator</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">ux</span>
<span class="n">varx</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">*</span> <span class="n">numerator</span>
<span class="n">varx</span> <span class="o">=</span> <span class="n">varx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">reduce_dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">denom</span> <span class="o">=</span> <span class="n">varx</span> <span class="o">+</span> <span class="n">eps</span>
<span class="n">denom</span> <span class="o">=</span> <span class="n">denom</span><span class="o">.</span><span class="n">sqrt</span><span class="p">()</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denom</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">old_shape</span><span class="p">)</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">num_channels</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">ndim</span><span class="p">)])</span>
<span class="k">if</span> <span class="n">weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">*</span> <span class="n">weight</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">bias</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
<span class="k">return</span> <span class="n">y</span></div>
<div class="viewcode-block" id="softplus"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.softplus">[docs]</a><span class="k">def</span> <span class="nf">softplus</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add the softplus activation base on PyTorch definition.</span>
<span class="sd"> See https://pytorch.org/docs/stable/generated/torch.nn.functional.softplus.html for a</span>
<span class="sd"> description of that function.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> Input TensorRT-LLM Tensor.</span>
<span class="sd"> beta : float</span>
<span class="sd"> The parameter for softplus computation.</span>
<span class="sd"> threshold : float</span>
<span class="sd"> The threshold for reverting to the linear function when input * beta &gt; threashold</span>
<span class="sd"> Returns:</span>
<span class="sd"> The output tensor created by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">sf_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">SOFTPLUS</span><span class="p">)</span>
<span class="n">sf_layer</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">beta</span>
<span class="n">sf_layer</span><span class="o">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">beta</span>
<span class="n">prod_tensor</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">*</span> <span class="n">beta</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">prod_tensor</span> <span class="o">&gt;</span> <span class="n">threshold</span>
<span class="k">return</span> <span class="n">where</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">sf_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">sf_layer</span><span class="p">))</span></div>
<div class="viewcode-block" id="outer"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.outer">[docs]</a><span class="k">def</span> <span class="nf">outer</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">vec2</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to compute the outer product between two tensors.</span>
<span class="sd"> That operation creates an Einsum node.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The first input tensor.</span>
<span class="sd"> vec2 : Tensor</span>
<span class="sd"> The second input tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The output tensor produced by this layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;i,j-&gt;ij&#39;</span><span class="p">,</span> <span class="p">[</span><span class="nb">input</span><span class="p">,</span> <span class="n">vec2</span><span class="p">])</span></div>
<div class="viewcode-block" id="avg_pool2d"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.avg_pool2d">[docs]</a><span class="k">def</span> <span class="nf">avg_pool2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">kernel_size</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
<span class="n">stride</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">padding</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">ceil_mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">count_include_pad</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="c1">##</span>
<span class="c1">## TODO: Document that function!</span>
<span class="c1">##</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_pooling</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PoolingType</span><span class="o">.</span><span class="n">AVERAGE</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">)</span>
<span class="k">if</span> <span class="n">stride</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">stride</span> <span class="o">=</span> <span class="n">kernel_size</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">stride</span> <span class="o">=</span> <span class="n">stride</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
<span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="conv2d"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.conv2d">[docs]</a><span class="k">def</span> <span class="nf">conv2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stride</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">dilation</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="c1">##</span>
<span class="c1">## TODO: Document that function!</span>
<span class="c1">##</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">kernel_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
<span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_convolution_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">noutput</span><span class="p">,</span>
<span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
<span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="n">padding</span>
<span class="n">layer</span><span class="o">.</span><span class="n">dilation</span> <span class="o">=</span> <span class="n">dilation</span>
<span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
<span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="conv_transpose2d"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.conv_transpose2d">[docs]</a><span class="k">def</span> <span class="nf">conv_transpose2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stride</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">output_padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">dilation</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="w">    </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd">    Add an operation that performs a 2D transposed convolution (also known</span>
<span class="sd">    as a deconvolution).</span>
<span class="sd">    That operation is implemented using a TensorRT IDeconvolutionLayer. The</span>
<span class="sd">    input tensor must be static (no dynamic dimension). A 3D input is</span>
<span class="sd">    treated as unbatched: it is expanded to 4D before the layer is added</span>
<span class="sd">    and the output is reshaped back to 3D.</span>
<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor.</span>
<span class="sd">        weight : Tensor</span>
<span class="sd">            The kernel tensor, laid out as in PyTorch ConvTranspose2d:</span>
<span class="sd">            (in_channels, out_channels // groups, kernel_height, kernel_width).</span>
<span class="sd">        bias : Optional[Tensor]</span>
<span class="sd">            The optional bias tensor.</span>
<span class="sd">        stride : Tuple[int, int]</span>
<span class="sd">            The stride of the transposed convolution.</span>
<span class="sd">        padding : Tuple[int, int]</span>
<span class="sd">            The padding applied to the input.</span>
<span class="sd">        output_padding : Tuple[int, int]</span>
<span class="sd">            Additional size added to the output shape (currently not</span>
<span class="sd">            forwarded to the layer).</span>
<span class="sd">        dilation : Tuple[int, int]</span>
<span class="sd">            The spacing between kernel elements (currently not forwarded</span>
<span class="sd">            to the layer).</span>
<span class="sd">        groups : int</span>
<span class="sd">            The number of groups the channels are split into.</span>
<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by the transposed-convolution layer.</span>
<span class="sd">    &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">kernel_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
<span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_deconvolution_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">noutput</span><span class="p">,</span>
<span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
<span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="n">padding</span>
<span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
<span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
<span class="k">return</span> <span class="n">output</span></div>
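For reference, the spatial size produced by a transposed convolution follows the standard formula. A small pure-Python helper (hypothetical, not part of this module) makes the rule concrete:

```python
# Hypothetical helper illustrating the standard transposed-convolution
# output-size formula; not part of tensorrt_llm.functional.
def deconv_out_size(in_size: int, kernel: int, stride: int = 1,
                    padding: int = 0, output_padding: int = 0,
                    dilation: int = 1) -> int:
    # Inverse of the convolution size rule:
    # out = (in - 1) * stride - 2 * padding
    #       + dilation * (kernel - 1) + output_padding + 1
    return ((in_size - 1) * stride - 2 * padding
            + dilation * (kernel - 1) + output_padding + 1)

# A 4x4 feature map upsampled with a 2x2 kernel and stride 2 doubles to 8x8.
print(deconv_out_size(4, kernel=2, stride=2))  # 8
```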
<div class="viewcode-block" id="split"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.split">[docs]</a><span class="k">def</span> <span class="nf">split</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">split_size_or_sections</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that splits a tensor into sub-tensors.</span>
<span class="sd"> This operation creates a list of tensors that are obtained from the input</span>
<span class="sd"> tensor by slicing it along the dimension &#39;dim&#39;. If &#39;split_size_or_sections&#39;</span>
<span class="sd"> is an integer, the tensor is split into &#39;input.shape[dim] /</span>
<span class="sd"> split_size_or_sections&#39; slices. If &#39;split_size_or_sections&#39; is a list of</span>
<span class="sd"> sizes, the tensor is split into &#39;len(split_size_or_sections)&#39; slices and</span>
<span class="sd"> the size of the ith slice is given by &#39;split_size_or_sections[i]&#39;.</span>
<span class="sd"> There are several constraints with the current implementation:</span>
<span class="sd"> - The input tensor must be static (no dynamic dimension),</span>
<span class="sd"> - If &#39;split_size_or_sections&#39; is an integer, the number of elements in</span>
<span class="sd"> the &#39;dim&#39; dimension of the input must be a multiple of</span>
<span class="sd"> &#39;split_size_or_sections&#39;: &#39;input.shape[dim] % split_size_or_sections == 0&#39;.</span>
<span class="sd"> - If &#39;split_size_or_sections&#39; is a sequence, the sum of the elements in</span>
<span class="sd"> &#39;split_size_or_sections&#39; must be equal to the size in the dimension</span>
<span class="sd"> &#39;dim&#39;: &#39;input.shape[dim] == sum(ii for ii in split_size_or_sections)&#39;.</span>
<span class="sd"> That operation is implemented using a &#39;slice&#39; operation for each output</span>
<span class="sd"> slice.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The input tensor to slice.</span>
<span class="sd"> split_size_or_sections : Union[int, Sequence[int]]</span>
<span class="sd">            If it is an integer, it is the size of each slice. If it is a</span>
<span class="sd">            sequence, the ith element gives the size of the ith slice.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension of the tensor to slice.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The list of tensors produced by the different operations.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="n">tensor</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">dim</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">dim</span> <span class="o">+=</span> <span class="n">ndim</span>
<span class="n">dim_value</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">dim</span><span class="p">]</span>
<span class="n">starts</span> <span class="o">=</span> <span class="p">[</span><span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="mi">0</span><span class="p">]))</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
<span class="n">sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">split_size_or_sections</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="c1"># TODO: support non-divisible cases</span>
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">%</span> <span class="n">split_size_or_sections</span> <span class="o">==</span> <span class="mi">0</span>
<span class="n">num_sections</span> <span class="o">=</span> <span class="n">dim_value</span> <span class="o">//</span> <span class="n">split_size_or_sections</span>
<span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">split_size_or_sections</span><span class="p">]))</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_sections</span><span class="p">):</span>
<span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">split_size_or_sections</span> <span class="o">*</span> <span class="n">i</span><span class="p">]))</span>
<span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">total_size</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">split_size_or_sections</span><span class="p">:</span>
<span class="n">total_size</span> <span class="o">+=</span> <span class="n">i</span>
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">==</span> <span class="n">total_size</span>
<span class="n">num_sections</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">split_size_or_sections</span><span class="p">)</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_sections</span><span class="p">):</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">+</span> <span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span>
<span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">split_size_or_sections</span><span class="p">[</span><span class="n">i</span><span class="p">]]))</span>
<span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
<span class="k">return</span> <span class="n">outputs</span></div>
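The start/size bookkeeping above can be sketched outside TensorRT with NumPy (a hypothetical stand-in for the per-slice TensorRT 'slice' layers, assuming a static shape):

```python
import numpy as np

# Pure-NumPy sketch of the split logic above: one slice per output section,
# mirroring the start/size bookkeeping done with TensorRT 'slice' layers.
def split_sketch(array, split_size_or_sections, dim=0):
    dim_value = array.shape[dim]
    if isinstance(split_size_or_sections, int):
        # Integer case: the size in 'dim' must be a multiple of the split size.
        assert dim_value % split_size_or_sections == 0
        sections = [split_size_or_sections] * (dim_value // split_size_or_sections)
    else:
        # Sequence case: the sizes must sum to the size in 'dim'.
        assert dim_value == sum(split_size_or_sections)
        sections = list(split_size_or_sections)
    outputs, start = [], 0
    for size in sections:
        index = [slice(None)] * array.ndim
        index[dim] = slice(start, start + size)
        outputs.append(array[tuple(index)])
        start += size
    return outputs

parts = split_sketch(np.arange(12).reshape(4, 3), [1, 3], dim=0)
print([p.shape for p in parts])  # [(1, 3), (3, 3)]
```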
<div class="viewcode-block" id="chunk"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.chunk">[docs]</a><span class="k">def</span> <span class="nf">chunk</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">chunks</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that splits a tensor into sub-tensors.</span>
<span class="sd"> This operation creates a list of tensors that are obtained from the input</span>
<span class="sd"> tensor by chunking it along the dimension &#39;dim&#39;. It produces &#39;chunks&#39;</span>
<span class="sd"> sub-tensors.</span>
<span class="sd"> That operation is only defined for static tensors (no dynamic dimension)</span>
<span class="sd"> and the size of the tensor in the dimension &#39;dim&#39; must be a multiple of</span>
<span class="sd"> &#39;chunks&#39;: &#39;input.shape[dim] % chunks == 0&#39;.</span>
<span class="sd"> It maps to &#39;split&#39; with &#39;split_size = input.shape[dim] / chunks&#39;.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The input tensor to slice.</span>
<span class="sd"> chunks : int</span>
<span class="sd"> The number of slices to split the input tensor into.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension of the tensor to slice.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The list of tensors produced by the different operations.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="n">tensor</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">dim</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">dim</span> <span class="o">+=</span> <span class="n">ndim</span>
<span class="n">dim_value</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">dim</span><span class="p">]</span>
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">%</span> <span class="n">chunks</span> <span class="o">==</span> <span class="mi">0</span>
<span class="k">return</span> <span class="n">split</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">dim_value</span> <span class="o">//</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
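Since 'chunk' reduces to 'split' with 'split_size = input.shape[dim] / chunks', the mapping can be checked on a plain Python list (a hypothetical sketch, not the TensorRT path):

```python
# Sketch of the 'chunk' -> 'split' mapping on a plain Python list.
def chunk_sketch(seq, chunks):
    # The length must be a multiple of 'chunks', as in the real operation.
    assert len(seq) % chunks == 0
    split_size = len(seq) // chunks
    return [seq[i * split_size:(i + 1) * split_size] for i in range(chunks)]

print(chunk_sketch(list(range(6)), 3))  # [[0, 1], [2, 3], [4, 5]]
```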
<div class="viewcode-block" id="AllReduceStrategy"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceStrategy">[docs]</a><span class="k">class</span> <span class="nc">AllReduceStrategy</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Warning: the actual definition is in</span>
<span class="sd">    cpp/tensorrt_llm/kernels/customAllReduceKernels.h; the two must be kept in sync.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">RING</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">ONESHOT</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">TWOSHOT</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">AUTO</span> <span class="o">=</span> <span class="mi">3</span></div>
<div class="viewcode-block" id="allreduce"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.allreduce">[docs]</a><span class="k">def</span> <span class="nf">allreduce</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
<span class="n">workspace</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">instance_id</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AllReduceStrategy</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that performs a collective all-reduce.</span>
<span class="sd">    Let&#39;s define &#39;world_size&#39; as the length of the &#39;group&#39; list. That function</span>
<span class="sd"> creates a layer to compute the sum of &#39;world_size&#39; tensors distributed</span>
<span class="sd"> amongst the &#39;world_size&#39; participating ranks (one GPU per rank).</span>
<span class="sd"> The list &#39;group&#39; contains the identifiers of the ranks participating into</span>
<span class="sd"> the collective operation.</span>
<span class="sd"> The tensors in the different ranks must be 1D tensors (or views) and the output</span>
<span class="sd"> tensor will have that same shape. The output tensor will be replicated on</span>
<span class="sd"> the &#39;world_size&#39; ranks.</span>
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL all-reduce</span>
<span class="sd"> collective operation. See</span>
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce</span>
<span class="sd"> for details.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> group : List[int]</span>
<span class="sd"> The ranks participating into the all-reduce operation.</span>
<span class="sd"> workspace: Optional[Tensor]</span>
<span class="sd"> When using CUSTOM or AUTO mode, a tensor containing pointers to memory</span>
<span class="sd">            visible to all GPUs. It should be 3 pointers per TP rank:</span>
<span class="sd"> ptr to data buffer, ptr to barriers in, ptr to barriers out.</span>
<span class="sd">            It must be initialized using the IpcMemory class.</span>
<span class="sd"> instance_id: int</span>
<span class="sd"> Used for synchronization with CUSTOM or AUTO. Corresponding plugins MUST have the same</span>
<span class="sd"> instance_id. I.e. GPU#0&#39;s allreduce after MLP at layer i must have the same instance_id as</span>
<span class="sd">            GPU#1, GPU#2... Also, instance_id MUST be unique per model: there must not be two</span>
<span class="sd">            allreduce instances on GPU#0 that share the same id.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">allreduce_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;AllReduce&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">if</span> <span class="n">strategy</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_custom_all_reduce</span><span class="p">:</span>
<span class="n">strategy</span> <span class="o">=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">AUTO</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">strategy</span> <span class="o">=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">RING</span>
<span class="k">assert</span> <span class="n">allreduce_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;group&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
<span class="n">pf_dtype</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="p">[</span><span class="n">group</span><span class="p">,</span> <span class="n">pf_dtype</span><span class="p">]</span>
<span class="n">p_strategy</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;strategy&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">strategy</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_strategy</span><span class="p">)</span>
<span class="n">p_counter</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;counter&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">instance_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_counter</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span><span class="n">pfc</span><span class="p">)</span>
<span class="n">ar_plug</span> <span class="o">=</span> <span class="n">allreduce_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;allreduce&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="k">if</span> <span class="n">strategy</span> <span class="o">!=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">RING</span><span class="p">:</span>
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">workspace</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">ar_plug</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
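Semantically, the layer leaves every participating rank with the elementwise sum of all ranks' inputs, replicated on each rank. A pure-Python simulation of that behavior (a hypothetical helper, no NCCL involved):

```python
# Simulation of all-reduce (sum) semantics: every rank ends up with the
# elementwise sum of all ranks' 1D inputs, replicated on each rank.
def allreduce_sum_sketch(rank_tensors):
    reduced = [sum(vals) for vals in zip(*rank_tensors)]
    # The reduced result is replicated on every rank.
    return [list(reduced) for _ in rank_tensors]

ranks = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(allreduce_sum_sketch(ranks))  # [[9.0, 12.0], [9.0, 12.0], [9.0, 12.0]]
```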
<div class="viewcode-block" id="allgather"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.allgather">[docs]</a><span class="k">def</span> <span class="nf">allgather</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that performs a collective all-gather.</span>
<span class="sd">    Let&#39;s define &#39;world_size&#39; as the length of the &#39;group&#39; list. That function</span>
<span class="sd"> creates a layer to gather &#39;world_size&#39; tensors distributed</span>
<span class="sd"> amongst the &#39;world_size&#39; participating ranks (one GPU per rank).</span>
<span class="sd"> The list &#39;group&#39; contains the identifiers of the ranks participating into</span>
<span class="sd"> the collective operation.</span>
<span class="sd"> The tensors in the different ranks must be 1D tensors (or views) and the</span>
<span class="sd"> output tensor will have that same shape.</span>
<span class="sd"> Given the &#39;section_size = input.shape[0] / world_size&#39;, each rank</span>
<span class="sd"> contributes a section of its input tensor that corresponds to</span>
<span class="sd"> &#39;rank*section_size:(rank+1)*section_size&#39;.</span>
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL all-gather</span>
<span class="sd"> collective operation. See</span>
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather</span>
<span class="sd"> for details.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> group : List[int]</span>
<span class="sd"> The ranks participating in the all-gather operation.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">allgather_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;AllGather&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">allgather_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;group&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">group</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
<span class="n">allgather</span> <span class="o">=</span> <span class="n">allgather_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;allgather&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">allgather</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
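The section bookkeeping described in the docstring can be sketched outside of TensorRT. The following NumPy simulation is a hedged illustration of the all-gather semantics only (the hypothetical helper `simulate_allgather` is not part of this module, and it stands in for the NCCL plugin): each rank contributes the `rank*section_size:(rank+1)*section_size` slice of its 1-D tensor, and every rank ends up with the concatenation of all sections.

```python
import numpy as np

def simulate_allgather(rank_inputs):
    """Simulate the all-gather section semantics (illustration only).

    Each rank holds a 1-D tensor of the same shape; rank r contributes
    the slice [r*section_size : (r+1)*section_size] of its own tensor.
    """
    world_size = len(rank_inputs)
    n = rank_inputs[0].shape[0]
    assert n % world_size == 0, "input length must divide evenly by world_size"
    section_size = n // world_size
    out = np.empty(n, dtype=rank_inputs[0].dtype)
    for r, t in enumerate(rank_inputs):
        # Copy rank r's own section into the gathered output.
        out[r * section_size:(r + 1) * section_size] = \
            t[r * section_size:(r + 1) * section_size]
    return out

# Two ranks; only each rank's own section is meaningful in its input.
rank0 = np.array([0, 1, 10, 11])
rank1 = np.array([20, 21, 2, 3])
print(simulate_allgather([rank0, rank1]))  # [0 1 2 3]
```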
<div class="viewcode-block" id="send"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.send">[docs]</a><span class="k">def</span> <span class="nf">send</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that performs a send from a rank to another.</span>
<span class="sd"> The send operation sends a tensor from one rank to another. If a rank &#39;i&#39;</span>
<span class="sd"> sends a tensor to a rank &#39;j&#39;, the rank &#39;j&#39; must have a corresponding &#39;recv&#39;</span>
<span class="sd"> operation from rank &#39;i&#39;. See &#39;recv&#39;.</span>
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL send</span>
<span class="sd"> point-to-point operation. See</span>
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend</span>
<span class="sd"> for details.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> tgt : int</span>
<span class="sd"> The rank that receives the tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">send_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;Send&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">send_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">tgt</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;tgt_rank&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tgt</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">tgt</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
<span class="n">send_plug</span> <span class="o">=</span> <span class="n">send_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;send&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">send_plug</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="recv"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.recv">[docs]</a><span class="k">def</span> <span class="nf">recv</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that performs a recv on a rank from another.</span>
<span class="sd"> The recv operation receives a tensor from another rank. If a rank &#39;i&#39;</span>
<span class="sd"> receives a tensor from a rank &#39;j&#39;, the rank &#39;j&#39; must have a corresponding &#39;send&#39;</span>
<span class="sd"> operation to rank &#39;i&#39;. See &#39;send&#39;.</span>
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL recv</span>
<span class="sd"> point-to-point operation. See</span>
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv</span>
<span class="sd"> for details.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> src : int</span>
<span class="sd"> The rank that sends the tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">recv_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;Recv&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">recv_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">src</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;src_rank&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">src</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
<span class="n">recv_plug</span> <span class="o">=</span> <span class="n">recv_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;recv&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">recv_plug</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
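The pairing requirement above — every send from rank 'i' to rank 'j' must be matched by a recv on rank 'j' from rank 'i' — can be illustrated with a small pure-Python simulation. The helpers `sim_send`/`sim_recv` below are hypothetical stand-ins for the NCCL point-to-point channels, not the plugin API:

```python
import queue

# One FIFO per ordered (src, tgt) pair stands in for the NCCL channel.
_channels = {}

def _channel(src: int, tgt: int) -> queue.Queue:
    return _channels.setdefault((src, tgt), queue.Queue())

def sim_send(tensor, src: int, tgt: int) -> None:
    # Rank 'src' posts a tensor destined for rank 'tgt'.
    _channel(src, tgt).put(tensor)

def sim_recv(src: int, tgt: int):
    # Rank 'tgt' blocks until the matching sim_send(src -> tgt) arrives;
    # an unmatched recv would hang (here: time out), mirroring NCCL.
    return _channel(src, tgt).get(timeout=1)

sim_send([1.0, 2.0], src=0, tgt=1)
print(sim_recv(src=0, tgt=1))  # [1.0, 2.0]
```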
<div class="viewcode-block" id="bert_attention"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.bert_attention">[docs]</a><span class="k">def</span> <span class="nf">bert_attention</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">input_lengths</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">head_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">q_scaling</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
<span class="n">relative_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">relative_attention_bias</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">max_distance</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that performs the multi-head attention in BERT.</span>
<span class="sd"> The multihead-attention (MHA) is the sequence of a batched matmul, a</span>
<span class="sd"> softmax and a batched matmul as described in</span>
<span class="sd"> https://arxiv.org/abs/1706.03762. That function adds an operation that</span>
<span class="sd"> performs those computations using a single GPU kernel.</span>
<span class="sd"> The input tensor contains the Q, K and V elements. It is a 2D tensor and</span>
<span class="sd"> its shape is &#39;[sum_of_tokens, 3*hidden_dim]&#39; where the &#39;sum_of_tokens&#39; is</span>
<span class="sd"> the sum of the sequence lengths in the batch.</span>
<span class="sd"> In MHA, the output of the Q*K^T product is scaled by a constant value that</span>
<span class="sd"> is computed as:</span>
<span class="sd"> 1.f / (q_scaling * sqrt(head_size)).</span>
<span class="sd"> That &#39;q_scaling&#39; constant is passed as an argument to that function.</span>
<span class="sd"> That layer is implemented using a plugin (see bertAttentionPlugin).</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The QKV input tensor.</span>
<span class="sd"> input_lengths : Tensor</span>
<span class="sd"> The length of each sequence. It is a 1D tensor of size &#39;batch_size&#39;.</span>
<span class="sd"> num_heads : int</span>
<span class="sd"> The number of heads.</span>
<span class="sd"> head_size : int</span>
<span class="sd"> The size of each head.</span>
<span class="sd"> q_scaling : float</span>
<span class="sd"> The factor to compute the scaling factor to scale the output of the</span>
<span class="sd"> &#39;Q*K^T&#39; product.</span>
<span class="sd"> relative_attention: bool = False</span>
<span class="sd"> Whether to enable relative attention.</span>
<span class="sd"> relative_attention_bias: Tensor = None</span>
<span class="sd"> The relative attention bias [num_heads, max_seq_len, max_seq_len], or the relative attention embedding table for implicit mode, [num_heads, num_buckets].</span>
<span class="sd"> max_distance: int = 0</span>
<span class="sd"> The maximum distance of relative position in attention, for implicit mode.</span>
<span class="sd"> Default value is 0, meaning to use the regular mode of relative attention bias.</span>
<span class="sd"> Implicit mode is only enabled when passing in non-zero positive max_distance value.</span>
<span class="sd"> See relative attention bias in docs/gpt_attention.md</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">attn_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;BertAttention&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">attn_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">nheads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;num_heads&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">head_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;head_size&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">head_size</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">q_scaling</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;q_scaling&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">q_scaling</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">enable_qk_half_accum</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;enable_qk_half_accum&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span>
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">attention_qk_half_accumulation</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">context_fmha_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;context_fmha_type&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha_type</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">bert_attention_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">do_relative_attention</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;do_relative_attention&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">relative_attention</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">max_distance</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;max_distance&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_distance</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
<span class="n">nheads</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span> <span class="n">q_scaling</span><span class="p">,</span> <span class="n">enable_qk_half_accum</span><span class="p">,</span> <span class="n">context_fmha_type</span><span class="p">,</span>
<span class="n">pf_type</span><span class="p">,</span> <span class="n">do_relative_attention</span><span class="p">,</span> <span class="n">max_distance</span>
<span class="p">])</span>
<span class="n">attn_plug</span> <span class="o">=</span> <span class="n">attn_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;padding_attn&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="p">,</span> <span class="n">input_lengths</span><span class="p">]</span>
<span class="k">if</span> <span class="n">relative_attention_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">relative_attention_bias</span><span class="p">]</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">attn_plug</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> \
<span class="sa">f</span><span class="s2">&quot;Plugin outputs number mismatch with expected, got </span><span class="si">{</span><span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span><span class="si">}</span><span class="s2">, expected 1&quot;</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">output</span></div>
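The scaling constant from the docstring, `1.f / (q_scaling * sqrt(head_size))`, is easy to check numerically. A minimal sketch (the helper name `mha_scale` is illustrative, not part of this module):

```python
import math

def mha_scale(q_scaling: float, head_size: int) -> float:
    # Scaling constant applied to the Q*K^T product, per the docstring:
    # 1 / (q_scaling * sqrt(head_size)).
    return 1.0 / (q_scaling * math.sqrt(head_size))

# With head_size = 64 and the usual q_scaling = 1.0, the attention
# logits are multiplied by 1/8 before the softmax.
print(mha_scale(1.0, 64))  # 0.125
```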
<div class="viewcode-block" id="gpt_attention"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gpt_attention">[docs]</a><span class="nd">@gw</span><span class="o">.</span><span class="n">record_signature</span>
<span class="k">def</span> <span class="nf">gpt_attention</span><span class="p">(</span>
<span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">past_key_value</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">sequence_length</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_past_key_value_lengths</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">num_kv_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">hidden_size_per_head</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">q_scaling</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
<span class="n">rotary_embedding_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">rotary_embedding_base</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
<span class="n">rotary_embedding_scale_type</span><span class="p">:</span> <span class="n">RotaryScalingType</span> <span class="o">=</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">none</span><span class="p">,</span>
<span class="n">rotary_embedding_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">rotary_embedding_max_positions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1024</span><span class="p">,</span>
<span class="n">position_embedding_type</span><span class="p">:</span> <span class="n">PositionEmbeddingType</span> <span class="o">=</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span>
<span class="n">learned_absolute</span><span class="p">,</span>
<span class="n">multi_block_mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">kv_orig_quant_scale</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">kv_quant_orig_scale</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">kv_cache_quant_mode</span><span class="p">:</span> <span class="n">QuantMode</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">mask_type</span><span class="p">:</span> <span class="n">AttentionMaskType</span> <span class="o">=</span> <span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">causal</span><span class="p">,</span>
<span class="n">alibi_slopes</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">tp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">do_cross_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">cross_qkv</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
<span class="n">cross_qkv_length</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
<span class="n">relative_attention_bias</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for relative attention</span>
<span class="n">max_distance</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="c1"># for relative attention</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for pad-free input mode</span>
<span class="n">qkv_bias</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that performs the multi-head attention in GPT-like models.</span>
<span class="sd"> The signature of the function will change in a future release - we are in</span>
<span class="sd"> the process of simplifying the API. The current version is still a</span>
<span class="sd"> work in progress! The API below is documented with hints about the</span>
<span class="sd"> arguments that are likely to be removed or merged with others in a future</span>
<span class="sd"> release.</span>
<span class="sd"> See docs/gpt_attention.md for the documentation of that function.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor: Tensor</span>
<span class="sd"> The input QKV tensor. Its shape is [batch_beam_size, max_seqlen, 3</span>
<span class="sd"> * hidden_dim] in padded mode and [1, num_tokens, 3 * hidden_dim] in</span>
<span class="sd"> packed mode. See QKV Input in docs/gpt_attention.md.</span>
<span class="sd"> past_key_value: Tensor</span>
<span class="sd"> The tensor that stores KV cache data. Its shape is</span>
<span class="sd"> [max_batch_size * max_beam_width, 2, num_heads, max_seqlen, hidden_dim_per_head]</span>
<span class="sd"> in contiguous mode and</span>
<span class="sd"> [max_blocks, 2, num_heads, num_tokens_per_block, hidden_dim_per_head]</span>
<span class="sd"> in paged mode. See KV Cache in docs/gpt_attention.md,</span>
<span class="sd"> sequence_length: Tensor</span>
<span class="sd"> The tensor that stores the length of each sequence. Its shape is</span>
<span class="sd"> [batch_size]. See QKV Input in docs/gpt_attention.md,</span>
<span class="sd"> host_past_key_value_lengths: Tensor</span>
<span class="sd"> An INT32 tensor of shape [batch_size].</span>
<span class="sd"> context_lengths: Tensor</span>
<span class="sd"> The tensor that stores the context-phase sequence length of each request. Its shape</span>
<span class="sd"> is [batch_size]. See QKV Input in docs/gpt_attention.md,</span>
<span class="sd"> cache_indirection: Tensor</span>
<span class="sd"> The tensor to reconstruct the paths when using beam-search. Its</span>
<span class="sd"> shape is [batch_size, beam_width, max_seqlen]. See Beam-Search in</span>
<span class="sd"> docs/gpt_attention.md,</span>
<span class="sd"> host_request_types: Tensor</span>
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
<span class="sd"> in docs/gpt_attention.md,</span>
<span class="sd"> num_heads: int</span>
<span class="sd"> The number of heads,</span>
<span class="sd"> num_kv_heads: int</span>
<span class="sd"> The number of KV heads, generic to handle MHA/MQA/GQA,</span>
<span class="sd"> hidden_size_per_head: int</span>
<span class="sd"> The hidden size per head,</span>
<span class="sd"> q_scaling: float</span>
<span class="sd"> The value used to compute the scaling factor applied to the output</span>
<span class="sd"> of the Q*K^T product. See Scaling Factors in docs/gpt_attention.md,</span>
<span class="sd"> rotary_embedding_dim: int</span>
<span class="sd"> The dimension to compute RoPE. Use 0 when position_embedding_type is not RoPE.</span>
<span class="sd"> rotary_embedding_base: float</span>
<span class="sd"> The theta value to use for RoPE. Ignored when position_embedding_type is not RoPE.</span>
<span class="sd"> rotary_embedding_scale_type: RotaryScalingType</span>
<span class="sd"> The scaling type of RoPE. Ignored when position_embedding_type is not RoPE.</span>
<span class="sd"> Possible rotary scaling type:</span>
<span class="sd"> * RotaryScalingType.none</span>
<span class="sd"> * RotaryScalingType.linear</span>
<span class="sd"> * RotaryScalingType.dynamic</span>
<span class="sd"> rotary_embedding_scale: float</span>
<span class="sd"> The scale value to use for linear/dynamic scaling in RoPE.</span>
<span class="sd"> Ignored when position_embedding_type is not RoPE.</span>
<span class="sd"> Must be set to 1 (default) if rotary_embedding_scale_type is `none`.</span>
<span class="sd"> rotary_embedding_max_positions: int</span>
<span class="sd"> Needed only for `dynamic` RoPE scaling. Ignored otherwise.</span>
<span class="sd"> position_embedding_type: PositionEmbeddingType</span>
<span class="sd"> The position embedding type:</span>
<span class="sd"> * PositionEmbeddingType.learned_absolute</span>
<span class="sd"> * PositionEmbeddingType.relative</span>
<span class="sd"> * PositionEmbeddingType.rope_gptj</span>
<span class="sd"> * PositionEmbeddingType.rope_gpt_neox</span>
<span class="sd"> * PositionEmbeddingType.alibi</span>
<span class="sd"> * PositionEmbeddingType.alibi_with_scale</span>
<span class="sd"> multi_block_mode: bool</span>
<span class="sd"> Do we enable multi-block for the masked MHA. See Generation Phase</span>
<span class="sd"> in docs/gpt_attention.md,</span>
<span class="sd"> kv_orig_quant_scale: Tensor</span>
<span class="sd"> The tensor to store the scaling factor for quantization to INT8/FP8</span>
<span class="sd"> in the KV cache. Its shape is [1]. See INT8/FP8 KV Cache in</span>
<span class="sd"> docs/gpt_attention.md,</span>
<span class="sd"> kv_quant_orig_scale: Tensor</span>
<span class="sd"> The tensor to store the scaling factor for dequantization from</span>
<span class="sd"> INT8/FP8 in the KV cache. Its shape is [1]. See INT8/FP8 KV Cache</span>
<span class="sd"> in docs/gpt_attention.md,</span>
<span class="sd"> kv_cache_quant_mode: QuantMode (int flags)</span>
<span class="sd"> Do we enable the INT8 or FP8 KV cache?</span>
<span class="sd"> max_context_length: int32_t</span>
<span class="sd"> The length of the longest input sequence. See QKV Input in</span>
<span class="sd"> docs/gpt_attention.md,</span>
<span class="sd"> mask_type: AttentionMaskType = AttentionMaskType.causal</span>
<span class="sd"> The type of mask:</span>
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.padding for BERT,</span>
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.causal for GPT,</span>
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.bidirectional for ChatGLM,</span>
<span class="sd"> alibi_slopes: Tensor</span>
<span class="sd"> The ALiBi slopes. The ALiBi bias is computed on-the-fly in the kernel</span>
<span class="sd"> when possible,</span>
<span class="sd"> tp_size: int</span>
<span class="sd"> The number of processes/GPUs when tensor parallelism is activated,</span>
<span class="sd"> tp_rank: int</span>
<span class="sd"> The rank of that process (when running tensor parallelism),</span>
<span class="sd"> kv_cache_block_pointers: Tensor = None</span>
<span class="sd"> The tensor of block pointers for the KV cache. Its shape is</span>
<span class="sd"> [max_batch_size, max_beam_width, 2, max_blocks_per_sequence * 2]</span>
<span class="sd"> See KV cache section in docs/gpt_attention.md,</span>
<span class="sd"> do_cross_attention: bool = False</span>
<span class="sd"> Do we use this as cross attention instead of self attention,</span>
<span class="sd"> cross_qkv: Tensor = None</span>
<span class="sd"> The QKV tensor of encoder output hidden states. Its shape is [batch_size, max_seqlen, 3</span>
<span class="sd"> * hidden_dim] in padded mode and [1, num_tokens, 3 * hidden_dim] in</span>
<span class="sd"> packed mode,</span>
<span class="sd"> cross_qkv_length: Tensor = None</span>
<span class="sd"> The length of the longest encoder output sequence,</span>
<span class="sd"> encoder_input_lengths: Tensor</span>
<span class="sd"> The tensor that stores the length of each encoder input sequence. Its shape is [batch_size],</span>
<span class="sd"> relative_attention_bias: Tensor = None</span>
<span class="sd"> The relative attention bias [num_heads, max_seq_len, max_seq_len], or the relative attention embedding table for implicit mode, [num_heads, num_buckets].</span>
<span class="sd"> max_distance: int = 0</span>
<span class="sd"> The maximum distance of relative position in attention, for implicit mode.</span>
<span class="sd"> Default value is 0, meaning to use the regular mode of relative attention bias.</span>
<span class="sd"> Implicit mode is only enabled when passing in non-zero positive max_distance value.</span>
<span class="sd"> See relative attention bias in docs/gpt_attention.md</span>
<span class="sd"> host_context_lengths: Tensor = None</span>
<span class="sd"> A host copy of context_lengths, of shape [batch_size]; required when remove_input_padding is enabled (pad-free input mode),</span>
<span class="sd"> qkv_bias: Tensor = None</span>
<span class="sd"> The optional bias tensor applied to the fused QKV projection.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">host_request_types</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="p">(</span><span class="n">alibi_slopes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">==</span> <span class="p">(</span><span class="n">position_embedding_type</span><span class="o">.</span><span class="n">is_alibi</span><span class="p">())</span>
<span class="n">attn_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;GPTAttention&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">attn_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="n">host_context_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">(</span>
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span>
<span class="n">paged_kv_cache_flag</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_kv_cache</span>
<span class="n">nheads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;num_heads&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">num_kv_heads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;num_kv_heads&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_kv_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">head_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;head_size&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">hidden_size_per_head</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">unidirectional</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;unidirectional&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">q_scaling</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;q_scaling&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">q_scaling</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">rotary_embedding_dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;rotary_embedding_dim&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">rotary_embedding_base</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;rotary_embedding_base&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_base</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">rotary_embedding_scale_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;rotary_embedding_scale_type&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_scale_type</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">rotary_embedding_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;rotary_embedding_scale&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_scale</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">rotary_embedding_max_positions</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;rotary_embedding_max_positions&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_max_positions</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">position_embedding_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;position_embedding_type&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">position_embedding_type</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">context_fmha_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;context_fmha_type&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha_type</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;remove_input_padding&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">mask_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;mask_type&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">mask_type</span><span class="p">)],</span>
<span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">multi_block_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;multi_block_mode&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">multi_block_mode</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">tp_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;tp_size&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tp_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">tp_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;tp_rank&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tp_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">kv_cache_quant_mode_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;kv_cache_quant_mode&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">kv_cache_quant_mode</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">paged_kv_cache</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;paged_kv_cache&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">paged_kv_cache_flag</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">tokens_per_block</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;tokens_per_block&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;max_context_length&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="k">if</span> <span class="n">qkv_bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">qkv_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;qkv_bias_enabled&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">qkv_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;qkv_bias_enabled&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">do_cross_attention_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;do_cross_attention&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">do_cross_attention</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">max_distance</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;max_distance&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_distance</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
<span class="n">nheads</span><span class="p">,</span> <span class="n">num_kv_heads</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span> <span class="n">unidirectional</span><span class="p">,</span> <span class="n">q_scaling</span><span class="p">,</span>
<span class="n">position_embedding_type</span><span class="p">,</span> <span class="n">rotary_embedding_dim</span><span class="p">,</span> <span class="n">rotary_embedding_base</span><span class="p">,</span>
<span class="n">rotary_embedding_scale_type</span><span class="p">,</span> <span class="n">rotary_embedding_scale</span><span class="p">,</span>
<span class="n">rotary_embedding_max_positions</span><span class="p">,</span> <span class="n">tp_size</span><span class="p">,</span> <span class="n">tp_rank</span><span class="p">,</span> <span class="n">context_fmha_type</span><span class="p">,</span>
<span class="n">multi_block_mode</span><span class="p">,</span> <span class="n">kv_cache_quant_mode_field</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span>
<span class="n">mask_type</span><span class="p">,</span> <span class="n">paged_kv_cache</span><span class="p">,</span> <span class="n">tokens_per_block</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">,</span> <span class="n">qkv_bias_enabled</span><span class="p">,</span> <span class="n">do_cross_attention_field</span><span class="p">,</span>
<span class="n">max_distance</span>
<span class="p">])</span>
<span class="n">attn_plug</span> <span class="o">=</span> <span class="n">attn_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;causal_attn&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">tensor</span><span class="p">,</span>
<span class="n">sequence_length</span><span class="p">,</span>
<span class="n">host_past_key_value_lengths</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">cache_indirection</span><span class="p">,</span>
<span class="n">host_request_types</span><span class="p">,</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">kv_cache_block_pointers</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">past_key_value</span><span class="p">]</span>
<span class="k">if</span> <span class="n">kv_cache_quant_mode</span><span class="o">.</span><span class="n">has_kv_cache_quant</span><span class="p">():</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">kv_orig_quant_scale</span><span class="p">,</span> <span class="n">kv_quant_orig_scale</span><span class="p">]</span>
<span class="k">if</span> <span class="n">alibi_slopes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">alibi_slopes</span><span class="p">]</span>
<span class="k">if</span> <span class="n">relative_attention_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">relative_attention_bias</span><span class="p">]</span>
<span class="k">if</span> <span class="n">do_cross_attention</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">cross_qkv</span><span class="p">,</span> <span class="n">cross_qkv_length</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">]</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
<span class="k">if</span> <span class="n">qkv_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">qkv_bias</span><span class="p">]</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">attn_plug</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="n">present_key_value</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
<span class="n">present_key_value</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">present_key_value</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">expected_outputs</span> <span class="o">=</span> <span class="mi">2</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_outputs</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">assert</span> <span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span> <span class="o">==</span> <span class="n">expected_outputs</span><span class="p">,</span> \
<span class="sa">f</span><span class="s2">&quot;Plugin outputs number mismatch with expected, got </span><span class="si">{</span><span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span><span class="si">}</span><span class="s2">, expected </span><span class="si">{</span><span class="n">expected_outputs</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="k">if</span> <span class="n">kv_cache_quant_mode</span><span class="o">.</span><span class="n">has_int8_kv_cache</span><span class="p">()</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
<span class="c1"># past key value</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">6</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="c1"># present key value</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">present_key_value</span></div>
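The `kv_orig_quant_scale` / `kv_quant_orig_scale` pair passed to the plugin above corresponds to symmetric INT8 quantization of the KV cache over the `[-127, 127]` dynamic range set at the end of the function. A minimal NumPy sketch of that round trip (the names are reused from the plugin inputs for illustration only; the actual conversion happens inside the plugin kernel):

```python
import numpy as np

def quantize_kv(kv, kv_orig_quant_scale):
    # float -> int8: scale, round, then clamp to the [-127, 127]
    # dynamic range that set_dynamic_range(-127, 127) declares.
    q = np.clip(np.rint(kv * kv_orig_quant_scale), -127, 127)
    return q.astype(np.int8)

def dequantize_kv(kv_int8, kv_quant_orig_scale):
    # int8 -> float: the inverse scale recovers the original magnitude.
    return kv_int8.astype(np.float32) * kv_quant_orig_scale
```

With `kv_orig_quant_scale = 127 / amax` and `kv_quant_orig_scale = amax / 127`, the round trip reproduces the input to within one quantization step.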
<div class="viewcode-block" id="assertion"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.assertion">[docs]</a><span class="k">def</span> <span class="nf">assertion</span><span class="p">(</span><span class="n">condition</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">message</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_assertion</span><span class="p">(</span><span class="n">condition</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">message</span><span class="p">)</span></div>
<div class="viewcode-block" id="layer_norm"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.layer_norm">[docs]</a><span class="k">def</span> <span class="nf">layer_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">,</span>
<span class="n">use_diff_of_squares</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a layer-norm operation on a tensor.</span>
<span class="sd"> That operation applies the layer-normalization to its input tensor. In its</span>
<span class="sd"> simplest form, for large language models, the &#39;normalized_shape&#39; should be</span>
<span class="sd"> set to the hidden dimension of the activation tensor. Otherwise, it is the</span>
<span class="sd"> shape of the normalized fraction of the tensor (starting from the</span>
<span class="sd"> right-most dimension).</span>
<span class="sd"> The &#39;weight&#39; tensor corresponds to &#39;gamma&#39; in the layer-norm formula and</span>
<span class="sd"> &#39;bias&#39; is &#39;beta&#39;. The &#39;eps&#39; value is added to the variance before computing</span>
<span class="sd"> the squared-root.</span>
<span class="sd"> This implementation (when using the plugin) supports an additional flag to</span>
<span class="sd"> enable/disable the use of a difference of squares (&#39;Var = Mean(X^2) -</span>
<span class="sd"> Mean(X)^2&#39;).</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The tensor to normalize.</span>
<span class="sd"> normalized_shape : Union[int, Tuple[int]]</span>
<span class="sd"> The shape of the sub-tensor that is normalized. Use &#39;hidden_dim&#39; to</span>
<span class="sd"> normalize the inner-most dimension of an activation tensor in LLMs.</span>
<span class="sd"> weight : Optional[Tensor] = None</span>
<span class="sd"> The &#39;gamma&#39; term in layer-norm. Its shape must be</span>
<span class="sd"> &#39;normalized_shape&#39;.</span>
<span class="sd"> bias : Optional[Tensor] = None</span>
<span class="sd"> The &#39;beta&#39; term in layer-norm. Its shape must be</span>
<span class="sd"> &#39;normalized_shape&#39;.</span>
<span class="sd"> eps : float</span>
<span class="sd"> The epsilon term to be added to the variance in the squared-root.</span>
<span class="sd"> use_diff_of_squares : bool</span>
<span class="sd"> Does the plugin use the difference of squares to compute the</span>
<span class="sd"> variance?</span>
<span class="sd"> Returns:</span>
<span class="sd"> The output tensor of that operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">layernorm_plugin</span><span class="p">:</span>
<span class="nb">input</span><span class="p">,</span> <span class="n">weight</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">weight</span><span class="p">)</span>
<span class="nb">input</span><span class="p">,</span> <span class="n">bias</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span> <span class="c1"># FIXME: better way?</span>
<span class="n">axis</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">axis</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">)</span>
<span class="n">axes_mask</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()):</span>
<span class="n">axes_mask</span> <span class="o">|=</span> <span class="mi">1</span> <span class="o">&lt;&lt;</span> <span class="n">i</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_normalization</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">axes_mask</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">epsilon</span> <span class="o">=</span> <span class="n">eps</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;Layernorm&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">eps</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;eps&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">eps</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">use_diff_of_squares</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;use_diff_of_squares&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">use_diff_of_squares</span><span class="p">)],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">layernorm_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">eps</span><span class="p">,</span> <span class="n">use_diff_of_squares</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
<span class="n">layernorm_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;layernorm&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">normalized_shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">normalized_shape</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
<span class="n">normalized_shape</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="n">normalized_shape</span>
<span class="k">if</span> <span class="n">weight</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">weight</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
<span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">str_dtype_to_np</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)))</span>
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
<span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">str_dtype_to_np</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)))</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">layernorm_plug</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
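The `use_diff_of_squares` flag in the docstring above refers to computing the variance as `Var = Mean(X^2) - Mean(X)^2` instead of the two-pass `Mean((X - Mean(X))^2)`. A plain-NumPy sketch of the reference semantics (this is an illustration of the math, not the plugin implementation, and the function name is hypothetical):

```python
import numpy as np

def layer_norm_ref(x, gamma, beta, eps=1e-5, use_diff_of_squares=True):
    # Normalize over the last dimension, as for LLM hidden states.
    mean = x.mean(-1, keepdims=True)
    if use_diff_of_squares:
        # One-pass variance: Var = Mean(X^2) - Mean(X)^2.
        var = (x * x).mean(-1, keepdims=True) - mean * mean
    else:
        # Two-pass variance: Var = Mean((X - Mean(X))^2).
        var = ((x - mean) ** 2).mean(-1, keepdims=True)
    # 'gamma' and 'beta' play the roles of 'weight' and 'bias'.
    return (x - mean) / np.sqrt(var + eps) * gamma + beta
```

In exact arithmetic the two variants are identical; the difference-of-squares form saves a pass over the data at the cost of reduced numerical robustness when `Mean(X)^2` is large relative to the variance, which is why the plugin exposes it as a flag.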
<div class="viewcode-block" id="rms_norm"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.rms_norm">[docs]</a><span class="k">def</span> <span class="nf">rms_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-06</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a RMS norm operation on a tensor.</span>
<span class="sd"> That operation applies the rms-normalization to its input tensor. In its</span>
<span class="sd"> simplest form, for large language models, the &#39;normalized_shape&#39; should be</span>
<span class="sd"> set to the hidden dimension of the activation tensor. Otherwise, it is the</span>
<span class="sd"> shape of the normalized fraction of the tensor (starting from the</span>
<span class="sd"> right-most dimension).</span>
<span class="sd"> The &#39;weight&#39; tensor corresponds to &#39;gamma&#39; in the rms-norm formula.</span>
<span class="sd"> The &#39;eps&#39; value is added to the variance before computing the squared-root.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input: Tensor</span>
<span class="sd"> The tensor to normalize.</span>
<span class="sd"> normalized_shape : Union[int, Tuple[int]]</span>
<span class="sd"> The shape of the sub-tensor that is normalized. Use &#39;hidden_dim&#39; to</span>
<span class="sd"> normalize the inner-most dimension of an activation tensor in LLMs.</span>
<span class="sd"> weight : Optional[Tensor] = None</span>
<span class="sd"> The &#39;gamma&#39; term in layer-norm. Its shape must be</span>
<span class="sd"> &#39;normalized_shape&#39;.</span>
<span class="sd"> eps : float</span>
<span class="sd"> The epsilon term to be added to the variance in the squared-root.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The output tensor of that operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">rmsnorm_plugin</span><span class="p">:</span>
<span class="n">normalized_shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">normalized_shape</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
<span class="n">normalized_shape</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="n">normalized_shape</span>
<span class="n">dim</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="o">-</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">))])</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="n">input_dtype</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span>
<span class="n">fp32_input</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="s2">&quot;float32&quot;</span><span class="p">)</span>
<span class="n">varx</span> <span class="o">=</span> <span class="nb">pow</span><span class="p">(</span><span class="n">fp32_input</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)</span>
<span class="n">varx</span> <span class="o">=</span> <span class="n">varx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">denom</span> <span class="o">=</span> <span class="n">varx</span> <span class="o">+</span> <span class="n">eps</span>
<span class="n">denom</span> <span class="o">=</span> <span class="n">denom</span><span class="o">.</span><span class="n">sqrt</span><span class="p">()</span>
<span class="n">fp32_y</span> <span class="o">=</span> <span class="n">fp32_input</span> <span class="o">/</span> <span class="n">denom</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">fp32_y</span><span class="p">,</span> <span class="n">input_dtype</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">with</span> <span class="n">precision</span><span class="p">(</span><span class="s2">&quot;float32&quot;</span><span class="p">):</span>
<span class="n">varx</span> <span class="o">=</span> <span class="nb">pow</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)</span>
<span class="n">varx</span> <span class="o">=</span> <span class="n">varx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">denom</span> <span class="o">=</span> <span class="n">varx</span> <span class="o">+</span> <span class="n">eps</span>
<span class="n">denom</span> <span class="o">=</span> <span class="n">denom</span><span class="o">.</span><span class="n">sqrt</span><span class="p">()</span>
<span class="n">y</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">/</span> <span class="n">denom</span>
<span class="k">if</span> <span class="n">weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">*</span> <span class="n">weight</span>
<span class="k">return</span> <span class="n">y</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># TODO remove the plugin version if rmsnorm operation can be offloaded</span>
<span class="c1"># to Myelin.</span>
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;Rmsnorm&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">eps</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;eps&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">eps</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">rmsnorm_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">eps</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
<span class="n">rmsnorm_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;rmsnorm&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">normalized_shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">normalized_shape</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
<span class="n">normalized_shape</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="n">normalized_shape</span>
<span class="k">if</span> <span class="n">weight</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">weight</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
<span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">str_dtype_to_np</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)))</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">rmsnorm_plug</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
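The non-plugin branch above is plain RMSNorm math: compute the mean of squares over the normalized dimensions in float32, divide by its square root (plus `eps`), then apply the optional elementwise weight. A minimal NumPy sketch of that computation (the helper name is illustrative, not part of the library):

```python
import numpy as np

def rms_norm_ref(x, weight=None, eps=1e-6):
    # Mean of squares over the last dimension, accumulated in float32
    # like the strongly-typed branch above.
    var = np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True)
    y = x.astype(np.float32) / np.sqrt(var + eps)
    if weight is not None:
        y = y * weight
    return y.astype(x.dtype)

x = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
y = rms_norm_ref(x)  # each element divided by sqrt(mean of squares)
```

Note that, unlike LayerNorm, no mean is subtracted: only the root-mean-square of the inputs is used for normalization.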
<div class="viewcode-block" id="generate_alibi_slopes"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.generate_alibi_slopes">[docs]</a><span class="k">def</span> <span class="nf">generate_alibi_slopes</span><span class="p">(</span><span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dtype</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">tp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">tp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">alibi_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Compute the ALiBi slopes as described in https://arxiv.org/abs/2211.05100.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> num_heads : int</span>
<span class="sd"> The number of heads.</span>
<span class="sd"> dtype : trt.DataType</span>
<span class="sd"> The data type of the returned slopes</span>
<span class="sd"> tp_size : int</span>
<span class="sd"> The tensor parallelism size</span>
<span class="sd"> tp_rank : int</span>
<span class="sd"> The tensor parallelism rank</span>
<span class="sd">        alibi_scale : float</span>
<span class="sd">            The scaling factor applied to the slopes, 1.0 by default.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A constant tensor that contains the ALiBi slopes.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">start_head_id</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">end_head_id</span> <span class="o">=</span> <span class="n">num_heads</span>
<span class="k">if</span> <span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">rank_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="o">//</span> <span class="n">tp_size</span>
<span class="n">start_head_id</span> <span class="o">=</span> <span class="n">rank_heads</span> <span class="o">*</span> <span class="n">tp_rank</span>
<span class="n">end_head_id</span> <span class="o">=</span> <span class="n">start_head_id</span> <span class="o">+</span> <span class="n">rank_heads</span>
<span class="n">closest_power_of_2</span> <span class="o">=</span> <span class="mi">2</span><span class="o">**</span><span class="n">np</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">num_heads</span><span class="p">))</span>
<span class="c1"># FT&#39;s implementation</span>
<span class="c1"># https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/gen_relative_pos_bias.cu#L248</span>
<span class="n">slopes_ft</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">h_id</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">start_head_id</span><span class="p">,</span> <span class="n">end_head_id</span><span class="p">):</span>
<span class="k">if</span> <span class="n">h_id</span> <span class="o">&lt;</span> <span class="n">closest_power_of_2</span><span class="p">:</span>
<span class="n">slopes_ft</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">np</span><span class="o">.</span><span class="n">power</span><span class="p">(</span><span class="mi">2</span><span class="o">**</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="mi">2</span><span class="o">**-</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">closest_power_of_2</span><span class="p">)</span> <span class="o">-</span> <span class="mi">3</span><span class="p">))),</span>
<span class="n">h_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">slopes_ft</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">np</span><span class="o">.</span><span class="n">power</span><span class="p">(</span><span class="mi">2</span><span class="o">**</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="mi">2</span><span class="o">**-</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">closest_power_of_2</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span> <span class="o">-</span> <span class="mi">3</span><span class="p">))),</span>
<span class="p">(</span><span class="n">h_id</span> <span class="o">-</span> <span class="n">closest_power_of_2</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">slopes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">slopes_ft</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">slopes</span> <span class="o">=</span> <span class="n">alibi_scale</span> <span class="o">*</span> <span class="n">slopes</span>
    <span class="c1"># Note that for bfloat16, we cannot cast a numpy tensor from float32 to bfloat16</span>
    <span class="c1"># because numpy does not support bfloat16. Even if we used a custom type to define</span>
    <span class="c1"># np_bfloat16, the &quot;astype&quot; call here would be undefined.</span>
    <span class="c1"># So, we must use torch to cast the tensor from float32 to bfloat16, and then use</span>
    <span class="c1"># torch_to_numpy to cast it back.</span>
<span class="n">slopes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">slopes</span><span class="p">)</span>
<span class="n">slopes</span> <span class="o">=</span> <span class="n">slopes</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">dtype</span><span class="p">))</span>
<span class="n">slopes</span> <span class="o">=</span> <span class="n">torch_to_numpy</span><span class="p">(</span><span class="n">slopes</span><span class="p">)</span>
<span class="n">slopes</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">slopes</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="n">end_head_id</span> <span class="o">-</span> <span class="n">start_head_id</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="k">return</span> <span class="n">slopes</span></div>
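For a head count that is a power of two, the loop above reduces to the geometric sequence 2^(-8/num_heads), 2^(-16/num_heads), ...; the second branch interpolates extra slopes when the head count is not a power of two. A standalone NumPy sketch of the same recipe for the single-rank case (tp_size == 1; the helper name is illustrative):

```python
import numpy as np

def alibi_slopes_ref(num_heads):
    # Same recipe as the FT-style loop above, for tp_size == 1.
    closest_pow2 = 2 ** int(np.floor(np.log2(num_heads)))
    base = 2.0 ** (-(2.0 ** -(np.log2(closest_pow2) - 3)))
    slopes = [base ** (h + 1) for h in range(closest_pow2)]
    # Interpolated slopes for the heads beyond the closest power of two.
    extra_base = 2.0 ** (-(2.0 ** -(np.log2(2 * closest_pow2) - 3)))
    slopes += [extra_base ** (2 * (h - closest_pow2) + 1)
               for h in range(closest_pow2, num_heads)]
    return np.asarray(slopes, dtype=np.float32)

slopes = alibi_slopes_ref(8)  # 0.5, 0.25, ..., 2**-8
```

For 8 heads this gives slopes 1/2, 1/4, ..., 1/256, matching the geometric sequence used in the ALiBi paper.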
<div class="viewcode-block" id="generate_alibi_biases"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.generate_alibi_biases">[docs]</a><span class="k">def</span> <span class="nf">generate_alibi_biases</span><span class="p">(</span><span class="n">slopes</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">key_length</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Compute the ALiBi biases as described in https://arxiv.org/abs/2211.05100.</span>
<span class="sd"> The ALiBi biases are added to the result of the Q*K^T product in the</span>
<span class="sd"> multihead-attention block.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> slopes : Tensor</span>
<span class="sd"> The slopes.</span>
<span class="sd"> key_length : Tensor</span>
<span class="sd"> The size of the K vector per head.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A constant tensor that contains the ALiBi biases.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># We don&#39;t need to care about the batch size or query length since we can just broadcast</span>
<span class="c1"># across the batch and query dimensions</span>
<span class="n">trt_0</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="n">arange_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">key_length</span><span class="p">])</span>
<span class="n">arange_tensor</span> <span class="o">=</span> <span class="n">arange</span><span class="p">(</span><span class="n">trt_0</span><span class="p">,</span> <span class="n">key_length</span><span class="p">,</span> <span class="s2">&quot;float32&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">arange_shape</span><span class="p">)</span>
<span class="n">arange_tensor</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">arange_tensor</span><span class="p">,</span> <span class="s2">&quot;float32&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">slopes</span> <span class="o">*</span> <span class="n">arange_tensor</span></div>
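The bias for head h at key position k is simply slopes[h] * k; keeping the slopes shaped [1, H, 1, 1] and the positions shaped [1, 1, 1, K] lets broadcasting fill in the batch and query dimensions when the biases are added to Q*K^T. A NumPy sketch of that broadcast:

```python
import numpy as np

# Per-head slopes, shaped [1, num_heads, 1, 1] as produced by generate_alibi_slopes.
slopes = np.array([0.5, 0.25], dtype=np.float32).reshape(1, 2, 1, 1)
# Key positions 0..K-1, shaped [1, 1, 1, key_length].
positions = np.arange(4, dtype=np.float32).reshape(1, 1, 1, 4)
biases = slopes * positions  # broadcasts to [1, 2, 1, 4]
```

Each head gets a linearly increasing penalty over key positions, scaled by its own slope.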
<div class="viewcode-block" id="expand_mask"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_mask">[docs]</a><span class="k">def</span> <span class="nf">expand_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_len</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Expand an attention mask.</span>
<span class="sd"> That function adds the sequence of operations to expand from a tensor of</span>
<span class="sd"> shape &#39;[batch_size, src_seq_len]&#39; to a tensor of shape</span>
<span class="sd"> &#39;[batch_size, 1, tgt_seq_len, src_seq_len]&#39;. It can be used to create the</span>
<span class="sd"> mask applied to the Q*K^T product before the softmax operation in the</span>
<span class="sd"> multihead-attention block.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> mask : Tensor</span>
<span class="sd"> The input mask</span>
<span class="sd"> tgt_len : Optional[Tensor]</span>
<span class="sd"> The dimension of the 3rd dimension in the output tensor. If None,</span>
<span class="sd"> the 2nd dimension of the input is used.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor created by that sequence of operations.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">bsz</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">src_len</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">tgt_len</span> <span class="o">=</span> <span class="n">tgt_len</span> <span class="k">if</span> <span class="n">tgt_len</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">src_len</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">([</span><span class="n">bsz</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">src_len</span><span class="p">]))</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">concat</span><span class="p">([</span><span class="n">bsz</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">tgt_len</span><span class="p">,</span> <span class="n">src_len</span><span class="p">]))</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">mask</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="s1">&#39;float32&#39;</span><span class="p">))</span>
<span class="k">return</span> <span class="n">mask</span></div>
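In effect, the function turns a [batch_size, src_seq_len] 0/1 mask into additive attention biases: 0 for valid positions, -inf for padded ones, broadcast over the target length. A NumPy equivalent (an illustrative helper, not the library API):

```python
import numpy as np

def expand_mask_ref(mask, tgt_len=None):
    # mask: [batch, src_len], 1 = valid token, 0 = padding.
    bsz, src_len = mask.shape
    tgt_len = tgt_len if tgt_len is not None else src_len
    m = np.broadcast_to(mask.reshape(bsz, 1, 1, src_len),
                        (bsz, 1, tgt_len, src_len)).astype(np.float32)
    # Valid positions become 0.0, padded positions become -inf,
    # mirroring the where(mask == 0, -inf, 1 - mask) above.
    return np.where(m == 0, -np.inf, 1.0 - m)

out = expand_mask_ref(np.array([[1, 1, 0]]))  # shape [1, 1, 3, 3]
```

Adding this tensor to the Q*K^T scores before the softmax drives the padded columns to zero probability.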
<div class="viewcode-block" id="gather_last_token_logits"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gather_last_token_logits">[docs]</a><span class="k">def</span> <span class="nf">gather_last_token_logits</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">last_token_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Extract the logits that correspond to the last token from the hidden states.</span>
<span class="sd"> That function adds the operations to extract the logits of the last tokens</span>
<span class="sd"> in a batch of sequences.</span>
<span class="sd"> Depending on whether &#39;remove_input_padding&#39; is &#39;True&#39; or &#39;False&#39;, that</span>
<span class="sd"> function assumes inputs of different shapes.</span>
<span class="sd"> When &#39;remove_input_padding&#39; is &#39;True&#39;, the &#39;hidden_states&#39; tensor is</span>
<span class="sd"> assumed to be packed. It has a shape &#39;[num_tokens, hidden_dim]&#39; where</span>
<span class="sd"> &#39;num_tokens&#39; is the sum of the lengths of the sequences in the batch and</span>
<span class="sd"> &#39;hidden_dim&#39; is the hidden dimension. The &#39;last_tokens_ids&#39; is a 1D tensor</span>
<span class="sd"> that encodes the inclusive prefix-sums of the lengths of the sequences in</span>
<span class="sd"> the batch.</span>
<span class="sd"> When &#39;remove_input_padding&#39; is &#39;False&#39;, the &#39;hidden_states&#39; tensor is</span>
<span class="sd"> assumed to be padded. It has a shape &#39;[batch_size, max_seqlen, hidden_dim]&#39;</span>
<span class="sd"> where &#39;max_seqlen&#39; is the length of the longest sequence in the batch and</span>
<span class="sd"> &#39;hidden_dim&#39; is the hidden dimension. The &#39;last_token_ids&#39; is a 1D tensor</span>
<span class="sd"> that encodes the length of each sequence in the batch.</span>
<span class="sd"> In both cases, that function produces a tensor of shape &#39;[batch_size,</span>
<span class="sd"> hidden_size]&#39; where the row at index &#39;i&#39; corresponds to the logits of the</span>
<span class="sd"> last token from the &#39;i&#39;-th sequence.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> hidden_states : Tensor</span>
<span class="sd"> The hidden states</span>
<span class="sd"> last_token_ids : Tensor</span>
<span class="sd">            The inclusive prefix-sum of the lengths or the lengths of the</span>

<span class="sd"> sequences in the batch.</span>
<span class="sd"> remove_input_padding : bool</span>
<span class="sd"> Indicate if the hidden_states are packed (&#39;True&#39;) or padded</span>
<span class="sd"> (&#39;False&#39;).</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor created by that sequence of operations.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">last_token_ids</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span> <span class="n">hidden_states</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">index_select</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># [1, seq_len, hidden]</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)]))</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># only calculate logits for the last token</span>
<span class="c1"># [batch_size, seqlen, hidden_size] -&gt; [batch_size, hidden_size]</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">]))</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span>
<span class="n">last_token_ids</span><span class="p">,</span>
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)]))</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span>
<span class="n">hidden_states</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">indices</span><span class="o">=</span><span class="n">last_token_ids</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)]))</span>
<span class="k">return</span> <span class="n">hidden_states</span></div>
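For the padded path ('remove_input_padding' is 'False'), the view/expand/gather sequence amounts to fancy indexing: pick row lengths[i] - 1 from each sequence. A NumPy sketch under that assumption (the helper name is illustrative):

```python
import numpy as np

def gather_last_token_ref(hidden_states, seq_lengths):
    # hidden_states: [batch, max_seqlen, hidden]; seq_lengths: [batch].
    batch = hidden_states.shape[0]
    # Select the hidden state of the last real token of each sequence.
    return hidden_states[np.arange(batch), seq_lengths - 1]

h = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
out = gather_last_token_ref(h, np.array([2, 3]))  # rows at seq index 1 and 2
```

The result has shape [batch_size, hidden_size], one row per sequence, which is exactly what the logits head consumes during generation.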
<span class="n">ACT2FN</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;relu&#39;</span><span class="p">:</span> <span class="n">relu</span><span class="p">,</span>
<span class="s1">&#39;tanh&#39;</span><span class="p">:</span> <span class="n">tanh</span><span class="p">,</span>
<span class="s1">&#39;gelu&#39;</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
<span class="s1">&#39;gelu_new&#39;</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
<span class="s1">&#39;gelu_fast&#39;</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
<span class="s1">&#39;geglu&#39;</span><span class="p">:</span> <span class="n">geglu</span><span class="p">,</span>
<span class="s1">&#39;silu&#39;</span><span class="p">:</span> <span class="n">silu</span><span class="p">,</span>
<span class="s1">&#39;softplus&#39;</span><span class="p">:</span> <span class="n">softplus</span><span class="p">,</span>
<span class="s1">&#39;squared-relu&#39;</span><span class="p">:</span> <span class="n">squared_relu</span><span class="p">,</span>
<span class="s1">&#39;swiglu&#39;</span><span class="p">:</span> <span class="n">swiglu</span><span class="p">,</span>
<span class="s1">&#39;fast-swiglu&#39;</span><span class="p">:</span> <span class="n">swiglu</span><span class="p">,</span>
<span class="p">}</span>
<span class="n">GATED_ACT_2_ACT</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;swiglu&#39;</span><span class="p">:</span> <span class="s1">&#39;silu&#39;</span><span class="p">,</span>
<span class="s1">&#39;fast-swiglu&#39;</span><span class="p">:</span> <span class="s1">&#39;silu&#39;</span><span class="p">,</span>
<span class="s1">&#39;geglu&#39;</span><span class="p">:</span> <span class="s1">&#39;gelu&#39;</span><span class="p">,</span>
<span class="p">}</span>
<div class="viewcode-block" id="is_gated_activation"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.is_gated_activation">[docs]</a><span class="k">def</span> <span class="nf">is_gated_activation</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Is a given activation function gated?</span>
<span class="sd"> Parameters:</span>
<span class="sd"> activation : str</span>
<span class="sd"> The name of the activation function.</span>
<span class="sd"> Returns:</span>
<span class="sd"> True if the function is gated, False otherwise.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">activation</span> <span class="ow">in</span> <span class="n">ACT2FN</span>
<span class="k">return</span> <span class="n">activation</span> <span class="ow">in</span> <span class="n">GATED_ACT_2_ACT</span></div>
<div class="viewcode-block" id="non_gated_version"><a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.non_gated_version">[docs]</a><span class="k">def</span> <span class="nf">non_gated_version</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Given an activation function, get the non-gated version.</span>
<span class="sd"> If the activation function is non-gated, it returns the same activation</span>
<span class="sd"> function name.</span>
<span class="sd"> For example, that function returns &#39;silu&#39; for &#39;swiglu&#39; and &#39;relu&#39; for</span>
<span class="sd"> &#39;relu&#39;.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> activation : str</span>
<span class="sd"> The name of the activation function.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The name of the non-gated activation function.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">is_gated_activation</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
<span class="k">return</span> <span class="n">GATED_ACT_2_ACT</span><span class="p">[</span><span class="n">activation</span><span class="p">]</span>
<span class="k">return</span> <span class="n">activation</span></div>
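Usage of the two helpers above is straightforward: a gated activation name maps to its elementwise half, while plain names pass through unchanged. A standalone sketch (the mapping table is copied from this module to keep it self-contained):

```python
GATED_ACT_2_ACT = {'swiglu': 'silu', 'fast-swiglu': 'silu', 'geglu': 'gelu'}

def non_gated_version(activation):
    # Gated names map to their elementwise activation; others pass through.
    return GATED_ACT_2_ACT.get(activation, activation)

pairs = [(a, non_gated_version(a)) for a in ('swiglu', 'geglu', 'relu')]
```

This lets MLP builders instantiate the elementwise activation once and handle the gating multiply separately.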
</pre></div>
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>&#169; Copyright 2023, NVIDIA.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>