mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
4422 lines
490 KiB
HTML
4422 lines
490 KiB
HTML
<!DOCTYPE html>
|
||
<html class="writer-html5" lang="en" data-content_root="../../">
|
||
<head>
|
||
<meta charset="utf-8" />
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||
<title>tensorrt_llm.functional — tensorrt_llm documentation</title>
|
||
<link rel="stylesheet" type="text/css" href="../../_static/pygments.css?v=80d5e7a1" />
|
||
<link rel="stylesheet" type="text/css" href="../../_static/css/theme.css?v=19f00094" />
|
||
|
||
|
||
<!--[if lt IE 9]>
|
||
<script src="../../_static/js/html5shiv.min.js"></script>
|
||
<![endif]-->
|
||
|
||
<script src="../../_static/jquery.js?v=5d32c60e"></script>
|
||
<script src="../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
|
||
<script src="../../_static/documentation_options.js?v=5929fcd5"></script>
|
||
<script src="../../_static/doctools.js?v=888ff710"></script>
|
||
<script src="../../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||
<script src="../../_static/js/theme.js"></script>
|
||
<link rel="index" title="Index" href="../../genindex.html" />
|
||
<link rel="search" title="Search" href="../../search.html" />
|
||
</head>
|
||
|
||
<body class="wy-body-for-nav">
|
||
<div class="wy-grid-for-nav">
|
||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||
<div class="wy-side-scroll">
|
||
<div class="wy-side-nav-search" >
|
||
|
||
|
||
|
||
<a href="../../index.html" class="icon icon-home">
|
||
tensorrt_llm
|
||
</a>
|
||
<div role="search">
|
||
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
|
||
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
|
||
<input type="hidden" name="check_keywords" value="yes" />
|
||
<input type="hidden" name="area" value="default" />
|
||
</form>
|
||
</div>
|
||
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
|
||
<p class="caption" role="heading"><span class="caption-text">Contents:</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../architecture.html">TensorRT-LLM Architecture</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../gpt_runtime.html">C++ GPT Runtime</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../batch_manager.html">The Batch Manager in TensorRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../gpt_attention.html">Multi-head, Multi-query and Group-query Attention</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../precision.html">Numerical Precision</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../installation.html">Build TensorRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../performance.html">Performance of TensorRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../2023-05-19-how-to-debug.html">How to debug</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../2023-05-17-how-to-add-a-new-model.html">How to add a new model</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../graph-rewriting.html">Graph Rewriting Module</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../memory.html">Memory Usage of TensorRT-LLM</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.layers.html">Layers</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.functional.html">Functionals</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.models.html">Models</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../_cpp_gen/runtime.html">Runtime</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Blogs</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
|
||
</ul>
|
||
|
||
</div>
|
||
</div>
|
||
</nav>
|
||
|
||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
|
||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||
<a href="../../index.html">tensorrt_llm</a>
|
||
</nav>
|
||
|
||
<div class="wy-nav-content">
|
||
<div class="rst-content">
|
||
<div role="navigation" aria-label="Page navigation">
|
||
<ul class="wy-breadcrumbs">
|
||
<li><a href="../../index.html" class="icon icon-home" aria-label="Home"></a></li>
|
||
<li class="breadcrumb-item"><a href="../index.html">Module code</a></li>
|
||
<li class="breadcrumb-item active">tensorrt_llm.functional</li>
|
||
<li class="wy-breadcrumbs-aside">
|
||
</li>
|
||
</ul>
|
||
<hr/>
|
||
</div>
|
||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||
<div itemprop="articleBody">
|
||
|
||
<h1>Source code for tensorrt_llm.functional</h1><div class="highlight"><pre>
|
||
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.</span>
|
||
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
|
||
<span class="c1">#</span>
|
||
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
||
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
||
<span class="c1"># You may obtain a copy of the License at</span>
|
||
<span class="c1">#</span>
|
||
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
||
<span class="c1">#</span>
|
||
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
||
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
||
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
||
<span class="c1"># See the License for the specific language governing permissions and</span>
|
||
<span class="c1"># limitations under the License.</span>
|
||
<span class="kn">import</span> <span class="nn">math</span>
|
||
<span class="kn">import</span> <span class="nn">weakref</span>
|
||
<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">OrderedDict</span>
|
||
<span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">IntEnum</span>
|
||
<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
|
||
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
|
||
|
||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||
|
||
<span class="c1"># isort: off</span>
|
||
<span class="kn">import</span> <span class="nn">torch</span>
|
||
<span class="kn">import</span> <span class="nn">tensorrt</span> <span class="k">as</span> <span class="nn">trt</span>
|
||
<span class="c1"># isort: on</span>
|
||
|
||
<span class="kn">from</span> <span class="nn">.</span> <span class="kn">import</span> <span class="n">graph_rewriting</span> <span class="k">as</span> <span class="n">gw</span>
|
||
<span class="kn">from</span> <span class="nn">._common</span> <span class="kn">import</span> <span class="n">default_net</span><span class="p">,</span> <span class="n">default_trtnet</span><span class="p">,</span> <span class="n">precision</span>
|
||
<span class="kn">from</span> <span class="nn">._utils</span> <span class="kn">import</span> <span class="p">(</span><span class="n">bf16_array</span><span class="p">,</span> <span class="n">dim_resolve_negative</span><span class="p">,</span> <span class="n">dim_to_trt_axes</span><span class="p">,</span>
|
||
<span class="n">fp16_array</span><span class="p">,</span> <span class="n">fp32_array</span><span class="p">,</span> <span class="n">int32_array</span><span class="p">,</span> <span class="n">np_dtype_to_trt</span><span class="p">,</span>
|
||
<span class="n">str_dtype_to_np</span><span class="p">,</span> <span class="n">str_dtype_to_trt</span><span class="p">,</span> <span class="n">torch_to_numpy</span><span class="p">,</span>
|
||
<span class="n">trt_dtype_to_torch</span><span class="p">)</span>
|
||
<span class="kn">from</span> <span class="nn">.logger</span> <span class="kn">import</span> <span class="n">logger</span>
|
||
<span class="kn">from</span> <span class="nn">.plugin</span> <span class="kn">import</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span>
|
||
<span class="kn">from</span> <span class="nn">.quantization</span> <span class="kn">import</span> <span class="n">QuantMode</span>
|
||
|
||
|
||
<div class="viewcode-block" id="DimRange">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.DimRange">[docs]</a>
<span class="k">class</span> <span class="nc">DimRange</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    One DimRange object stores the ranges of all the dimensions of one tensor in one optimization profile.</span>
<span class="sd">    For example, if a tensor has 2 dimensions, the data members are:</span>
<span class="sd">        self.min = [dim 0 min, dim 1 min]</span>
<span class="sd">        self.opt = [dim 0 opt, dim 1 opt]</span>
<span class="sd">        self.max = [dim 0 max, dim 1 max]</span>
<span class="sd">    and self.dimension_names holds one name per dimension, in the same order.</span>
<span class="sd">    For a static dimension min == opt == max, so the corresponding element of the</span>
<span class="sd">    `shape` constructor argument can be a single integer.</span>
<span class="sd">    '''</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]],</span>
                 <span class="n">names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]):</span>
<span class="w">        </span><span class="sd">'''</span>
<span class="sd">        Parameters:</span>
<span class="sd">            shape: a list of length N, where N is the number of dimensions of the tensor.</span>
<span class="sd">                Each element is either an integer or a 3-element tuple/list of ints.</span>
<span class="sd">                An integer means that dimension is static.</span>
<span class="sd">                A tuple/list means that dimension is dynamic; its 3 elements are ordered as</span>
<span class="sd">                (min, opt, max), and this function asserts 0 &lt;= min &lt;= opt &lt;= max.</span>
<span class="sd">            names: a list of length N holding one name per element of shape.</span>

<span class="sd">        Raises:</span>
<span class="sd">            AssertionError: if len(names) != len(shape), or if a tuple/list element does not</span>
<span class="sd">                have exactly 3 entries with 0 &lt;= min &lt;= opt &lt;= max.</span>
<span class="sd">            AttributeError: if an element of shape is neither an int nor a tuple/list.</span>

<span class="sd">        Example: for a rank-3 tensor whose 1st dimension is static with value 16, whose 2nd</span>
<span class="sd">            dimension is dynamic with min/opt/max values 1/8/32, and whose 3rd dimension is</span>
<span class="sd">            static with value 8, the shape parameter could be:</span>
<span class="sd">                [16, (1, 8, 32), 8]</span>
<span class="sd">            which has the same semantics as:</span>
<span class="sd">                [(16, 16, 16), (1, 8, 32), (8, 8, 8)]</span>
<span class="sd">        '''</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">min</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">opt</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">max</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span> <span class="o">=</span> <span class="n">names</span>
        <span class="c1"># f-string so both offending lists are interpolated into the message on failure.</span>
        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">names</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span>
            <span class="n">shape</span>
        <span class="p">),</span> <span class="sa">f</span><span class="s2">"Expecting shape list and name list must have same length, got </span><span class="si">{</span><span class="n">shape</span><span class="si">=}</span><span class="s2">, </span><span class="si">{</span><span class="n">names</span><span class="si">=}</span><span class="s2">"</span>
        <span class="k">for</span> <span class="n">dim</span> <span class="ow">in</span> <span class="n">shape</span><span class="p">:</span>
            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
                <span class="c1"># Dynamic dimension given as (min, opt, max).</span>
                <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span> <span class="ow">and</span> <span class="mi">0</span> <span class="o">&lt;=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> \
                    <span class="sa">f</span><span class="s2">"Each dimension must specify a 3-elements tuple or list in the order of (min,opt,max), got </span><span class="si">{</span><span class="n">dim</span><span class="si">=}</span><span class="s2">"</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span>
            <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
                <span class="c1"># Static dimension: min == opt == max.</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">AttributeError</span><span class="p">(</span>
                    <span class="sa">f</span><span class="s1">'Dimension should be [min, opt, max] (dynamic shape) or int (specific value). Got </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span><span class="si">}</span><span class="s1">'</span>
                <span class="p">)</span>

    <span class="k">def</span> <span class="fm">__eq__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">__value</span><span class="p">:</span> <span class="nb">object</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
        <span class="c1"># Equal when the names and every per-dimension min/opt/max list match.</span>
        <span class="k">return</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">__value</span><span class="p">,</span> <span class="n">DimRange</span><span class="p">)</span> <span class="ow">and</span> \
            <span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">dimension_names</span> <span class="ow">and</span> \
            <span class="bp">self</span><span class="o">.</span><span class="n">min</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">min</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">opt</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">opt</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">max</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">max</span>

    <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
        <span class="k">return</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>

    <span class="k">def</span> <span class="fm">__str__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
        <span class="c1"># Formats exactly the fields __eq__ compares, so __hash__ (which hashes this string)</span>
        <span class="c1"># stays consistent with equality.</span>
        <span class="k">return</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span><span class="si">=}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="si">=}</span><span class="s2">, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="si">=}</span><span class="s2">, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="si">=}</span><span class="s2">"</span>

    <span class="k">def</span> <span class="fm">__hash__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
        <span class="k">return</span> <span class="nb">hash</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="p">))</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">Tensor</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The class to represent dense tensors.</span>
|
||
|
||
<span class="sd"> A dense tensor is named, has a shape and contains typed elements. Each</span>
|
||
<span class="sd"> dimension of a tensor can either be static or dynamic. Static dimensions</span>
|
||
<span class="sd"> are known at engine compilation by TensorRT. Dynamic dimensions can take</span>
|
||
<span class="sd"> values determined at runtime. The tensor can be located on the host (CPU)</span>
|
||
<span class="sd"> or the device (GPU).</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
||
<span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">dim_range</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">is_network_input</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">location</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">TensorLocation</span><span class="o">.</span><span class="n">DEVICE</span><span class="p">,</span>
|
||
<span class="n">network</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">trt_tensor</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> name : str</span>
|
||
<span class="sd"> The name of the tensor.</span>
|
||
|
||
<span class="sd"> dtype : tensorrt.DataType</span>
|
||
<span class="sd"> The type of the elements of the tensor. See the TensorRT</span>
|
||
<span class="sd"> documentation for list of supported data types.</span>
|
||
|
||
<span class="sd"> shape : tensorrt.Dims</span>
|
||
<span class="sd"> The dimensions of the tensor. In TensorRT-LLM, tensors can have</span>
|
||
<span class="sd"> static or dynamic dimensions (it is possible to mix static and</span>
|
||
<span class="sd"> dynamic dimensions). A static dimension is known when the</span>
|
||
<span class="sd"> TensorRT engine is built. A dynamic dimension can be set when</span>
|
||
<span class="sd"> the engine is executed. Use -1 for dynamic dimensions.</span>
|
||
|
||
<span class="sd"> dim_range : OrderedDict</span>
|
||
<span class="sd"> An ordered dictionary (the positions of the elements matter)</span>
|
||
<span class="sd"> that associates a name and a range of values to the dimensions.</span>
|
||
<span class="sd"> For a static dimension, the range must be limited to a single</span>
|
||
<span class="sd"> value. For a dynamic dimension, the range is defined by three</span>
|
||
<span class="sd"> values [min, opt, max] where min and max are, respectively, the</span>
|
||
<span class="sd"> smallest and largest possible values of that dimension. The</span>
|
||
<span class="sd"> opt value is used by TensorRT to optimize the engine for the</span>
|
||
<span class="sd"> most common case.</span>
|
||
|
||
<span class="sd"> Assume there are N optimization profiles; each item of the dim_range dict has the form:</span>
|
||
<span class="sd"> (dynamic dimension name : [profile 0 (min, opt, max), profile 1 (min, opt, max), ... profile N-1 (min, opt, max)])</span>
|
||
<span class="sd"> or, when the dimension is static (think of it as min==opt==max):</span>
|
||
<span class="sd"> (static dimension name : [profile 0 value, profile 1 value, ... profile N-1 value])</span>
|
||
<span class="sd"> For a static dimension, the values for profiles 0 to N-1 must all be the same. (TODO: can it be simplified to be only 1 value?)</span>
|
||
<span class="sd"> The number of keys equals the number of tensor dimensions, and every value list has exactly N entries (one per optimization profile).</span>
|
||
|
||
<span class="sd"> is_network_input : bool</span>
|
||
<span class="sd"> A boolean indicating if that tensor is an input of the network.</span>
|
||
<span class="sd"> Inputs must be provided by the user to run the engine.</span>
|
||
|
||
<span class="sd"> location : tensorrt.TensorLocation</span>
|
||
<span class="sd"> A flag to indicate where the tensor will be located. It can be</span>
|
||
<span class="sd"> on the host (CPU) or the device (GPU).</span>
|
||
|
||
<span class="sd"> network: Network</span>
|
||
<span class="sd"> The parent Network instance, which helps to find the users of this tensor.</span>
|
||
|
||
<span class="sd"> trt_tensor: trt.ITensor</span>
|
||
<span class="sd"> If given, wrap this ITensor instance directly; no shape profiles are required. In that case network must also be given and is_network_input must be False.</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="c1"># Layout of self.profiles</span>
|
||
<span class="c1"># Opt profile 0: dim 0 (min, opt, max), dim 1 (min, opt, max) ... dim M-1</span>
|
||
<span class="c1"># Opt profile 1: dim 0 (min, opt, max), dim 1 (min, opt, max) ... dim M-1</span>
|
||
<span class="c1"># ...</span>
|
||
<span class="c1"># Opt profile N-1: dim 0 ... dim M-1</span>
|
||
|
||
<span class="c1"># So converting the dim_range arg into self.profiles is a layout transpose:</span>
|
||
<span class="c1"># dim_range arg is {M dimensions x N profiles}, while self.profiles is {N profiles x M dimensions}</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">profiles</span> <span class="o">=</span> <span class="p">[]</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">is_tensor_wrapper</span> <span class="o">=</span> <span class="kc">False</span> <span class="c1"># specially for the graph rewriter</span>
|
||
|
||
<span class="c1"># work as a wrapper for a trt.ITensor, this is used specially in the graph rewriter</span>
|
||
<span class="k">if</span> <span class="n">trt_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">is_tensor_wrapper</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="k">assert</span> <span class="n">network</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="o">=</span> <span class="n">trt_tensor</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_network</span> <span class="o">=</span> <span class="n">weakref</span><span class="o">.</span><span class="n">ref</span><span class="p">(</span><span class="n">network</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="n">is_network_input</span><span class="p">,</span> <span class="s2">"is_network_input should be False when trt_tensor is not None"</span>
|
||
<span class="k">return</span>
|
||
|
||
<span class="c1"># Be cautious here: the weakref is critical to avoid a circular reference between Network and Tensor.</span>
|
||
<span class="c1"># Using a strong reference would likely cause a significant peak memory increase, since Network objects</span>
|
||
<span class="c1"># hold the weights data.</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_network</span> <span class="o">=</span> <span class="n">weakref</span><span class="o">.</span><span class="n">ref</span><span class="p">(</span><span class="n">default_net</span><span class="p">())</span>
|
||
<span class="k">if</span> <span class="n">is_network_input</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">dim_range</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim_range</span><span class="p">,</span> <span class="n">OrderedDict</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span>
|
||
<span class="n">dim_range</span>
|
||
<span class="p">)</span> <span class="o">>=</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"Each input tensor shall have at least one dimension, tensor '</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">' found </span><span class="si">{</span><span class="n">dim_range</span><span class="si">=}</span><span class="s2">"</span>
|
||
|
||
<span class="n">found_profiles</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="nb">len</span><span class="p">(</span><span class="n">ranges</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">ranges</span> <span class="ow">in</span> <span class="n">dim_range</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
|
||
<span class="p">]</span>
|
||
<span class="k">assert</span> <span class="nb">all</span><span class="p">(</span>
|
||
<span class="p">[</span><span class="n">x</span> <span class="o">==</span> <span class="n">found_profiles</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">found_profiles</span><span class="p">]</span>
|
||
<span class="p">),</span> <span class="sa">f</span><span class="s2">"Expecting all the dimensions in the dim_range has same number of profiles, tensor '</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">' got </span><span class="si">{</span><span class="n">dim_range</span><span class="si">=}</span><span class="s2">"</span>
|
||
|
||
<span class="n">num_opt_profile</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">dim_range</span><span class="o">.</span><span class="n">items</span><span class="p">())[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">])</span>
|
||
<span class="k">assert</span> <span class="n">num_opt_profile</span> <span class="o">>=</span> <span class="mi">1</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_opt_profile</span><span class="p">):</span>
|
||
<span class="n">range_shape</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="n">dimension_names</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">dim</span><span class="p">,</span> <span class="n">ranges</span> <span class="ow">in</span> <span class="n">dim_range</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">ranges</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">))</span>
|
||
<span class="n">range_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">ranges</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
|
||
<span class="n">dimension_names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">profiles</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">DimRange</span><span class="p">(</span><span class="n">range_shape</span><span class="p">,</span> <span class="n">dimension_names</span><span class="p">))</span>
|
||
|
||
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_add_input</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">dim_range</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">location</span> <span class="o">=</span> <span class="n">location</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span> <span class="nf">network</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_network</span><span class="p">()</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span> <span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The name of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span>
|
||
|
||
<span class="nd">@name</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span> <span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Set the name of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">name</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span> <span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The type of the elements in the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span>
|
||
|
||
<span class="nd">@dtype</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span> <span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Set the type of the elements in the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The shape of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
|
||
|
||
<span class="nd">@shape</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Set the shape of the tensor. See __init__.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">shape</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span> <span class="nf">location</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The physical location of the tensor (on the host or the device).</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">location</span>
|
||
|
||
<span class="nd">@location</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span> <span class="nf">location</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">location</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Set the physical location of the tensor (on the host or the device). See __init__.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">location</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">location</span> <span class="o">=</span> <span class="n">location</span>
|
||
|
||
<div class="viewcode-block" id="Tensor.mark_output">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.mark_output">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">mark_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Mark a tensor as a network output.</span>
|
||
|
||
<span class="sd"> When a tensor is marked as an output, its content can be obtained after</span>
|
||
<span class="sd"> the execution of the TensorRT engine. The user is responsible for</span>
|
||
<span class="sd"> allocating buffers to store the output tensors when preparing the</span>
|
||
<span class="sd"> execution of the TensorRT engine.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_mark_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
<span class="k">def</span> <span class="fm">__add__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.add.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__radd__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.add.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">add</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__sub__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.sub.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">sub</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__rsub__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.sub.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">sub</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__mul__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.mul.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">mul</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__rmul__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.mul.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">mul</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__truediv__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.div.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__lt__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.lt.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">lt</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__gt__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.gt.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">gt</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__eq__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.eq.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_tensor_wrapper</span><span class="p">:</span>
|
||
<span class="c1"># for graph rewriter</span>
|
||
<span class="k">return</span> <span class="nb">hash</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">==</span> <span class="nb">hash</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="c1"># for creating the network</span>
|
||
<span class="k">return</span> <span class="n">eq</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__ge__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Maps to functional.gt or functional.eq.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">op_or</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="fm">__gt__</span><span class="p">(</span><span class="n">b</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="fm">__eq__</span><span class="p">(</span><span class="n">b</span><span class="p">))</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__le__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Maps to functional.lt or functional.eq.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">op_or</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="fm">__lt__</span><span class="p">(</span><span class="n">b</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="fm">__eq__</span><span class="p">(</span><span class="n">b</span><span class="p">))</span>
|
||
|
||
<div class="viewcode-block" id="Tensor.view">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.view">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">view</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.view.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.permute">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.permute">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">permute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.permute.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">permute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.transpose">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.transpose">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">transpose</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim0</span><span class="p">,</span> <span class="n">dim1</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.transpose.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">transpose</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim0</span><span class="p">,</span> <span class="n">dim1</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.mean">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.mean">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">mean</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.mean.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">mean</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.max">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.max">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">max</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.max.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">max</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.abs">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.abs">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">abs</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.abs.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">abs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.sqrt">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.sqrt">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.sqrt.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.cast">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.cast">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">cast</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.cast.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">cast</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.size">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.size">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">size</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Returns the shape of the tensor if the dim parameter is None.</span>
|
||
<span class="sd"> Otherwise, returns a size of the dimension indicated by dim. The</span>
|
||
<span class="sd"> behavior is undefined if dim is negative or exceeds the rank of the</span>
|
||
<span class="sd"> tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span>
|
||
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.rank">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.rank">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">rank</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Returns the rank (i.e. the number of dimensions) of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.ndim">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.ndim">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">ndim</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Returns the rank (i.e. the number of dimensions) of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.split">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.split">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">split</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">split_size_or_sections</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.split.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">split_size_or_sections</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.is_dynamic">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.is_dynamic">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">is_dynamic</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> If the argument 'dim' is None, that function returns a boolean that</span>
|
||
<span class="sd"> indicates if the tensor contains a dynamic dimension (True) or not</span>
|
||
<span class="sd"> (False). In that case, the first dimension is excluded (as it usually</span>
|
||
<span class="sd"> corresponds to the batch size). If the argument is an integer, that</span>
|
||
<span class="sd"> functions returns a boolean that indicates if the dimension 'dim' is</span>
|
||
<span class="sd"> dynamic (True) or not (False).</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span>
|
||
|
||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">s</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="kc">True</span>
|
||
|
||
<span class="k">return</span> <span class="kc">False</span></div>
|
||
|
||
|
||
<span class="c1"># graph writer related functions</span>
|
||
|
||
<div class="viewcode-block" id="Tensor.get_parent">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.get_parent">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">get_parent</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">''' Get the layer that produces this tensor. '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">get_tensor_parent</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.get_users">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.get_users">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">get_users</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">''' Get the layers that use this tensor as an input. '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">get_tensor_users</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.replace_all_uses_with">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.replace_all_uses_with">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">replace_all_uses_with</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_tensor</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Replace all uses of this tensor as an input to consumer layers</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">is_graph_altered</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="n">users</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_users</span><span class="p">()</span>
|
||
<span class="k">for</span> <span class="n">user</span> <span class="ow">in</span> <span class="n">users</span><span class="p">:</span>
|
||
<span class="n">inputs_changed</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">user</span><span class="o">.</span><span class="n">num_inputs</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">user</span><span class="o">.</span><span class="n">get_inputs</span><span class="p">(</span><span class="n">i</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="ow">is</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">:</span>
|
||
<span class="n">inputs_changed</span> <span class="o">+=</span> <span class="mi">1</span>
|
||
<span class="n">user</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">inputs_changed</span> <span class="o">>=</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"Tensor not found in layer inputs"</span>
|
||
|
||
<span class="c1"># update the FLayerMetadata as well</span>
|
||
<span class="n">flayer</span> <span class="o">=</span> <span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">user</span><span class="o">.</span><span class="n">name</span><span class="p">)</span>
|
||
<span class="n">flayer</span> <span class="ow">and</span> <span class="n">flayer</span><span class="o">.</span><span class="n">replace_input_with</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_tensor</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.is_trt_wrapper">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.is_trt_wrapper">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">is_trt_wrapper</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Check if there is a trt.ITensor member inside, which is required for</span>
|
||
<span class="sd"> graph rewriter. In order to differentiate usages, it may be necessary</span>
|
||
<span class="sd"> to have an inheritance hierarchy.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s1">'trt_tensor'</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="kc">True</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="kc">False</span></div>
|
||
|
||
|
||
<span class="k">def</span> <span class="fm">__hash__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_trt_wrapper</span><span class="p">():</span>
|
||
<span class="k">return</span> <span class="nb">id</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="nb">id</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="k">def</span> <span class="nf">_create_tensor</span><span class="p">(</span><span class="n">trt_tensor</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ITensor</span><span class="p">,</span>
|
||
<span class="n">producer</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ILayer</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> A helper function to create a TensorRT-LLM Tensor object that encapsulates</span>
|
||
<span class="sd"> the connection between the TensorRT tensor (trt.ITensor) and the layer</span>
|
||
<span class="sd"> (trt.ILayer) that produces it.</span>
|
||
|
||
<span class="sd"> That function is expected to be used as:</span>
|
||
|
||
<span class="sd"> # Insert a new layer in the network using the TensorRT API:</span>
|
||
<span class="sd"> layer = default_trtnet().add_<some_layer>(...)</span>
|
||
<span class="sd"> # Extract the first output of that layer and connect it to the layer.</span>
|
||
<span class="sd"> return _create_tensor(layer.get_output(0), layer)</span>
|
||
|
||
<span class="sd"> That function also sets the precision of the layer/producer to the default</span>
|
||
<span class="sd"> precision of the network.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> trt_tensor : trt.ITensor</span>
|
||
<span class="sd"> The TensorRT tensor to connect to its producer (the layer).</span>
|
||
|
||
<span class="sd"> producer : trt.ILayer = None</span>
|
||
<span class="sd"> The producer.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The TensorRT-LLM tensor (functional.Tensor) that encapsulates the</span>
|
||
<span class="sd"> TensorRT tensor and the layer that produces it. The former is</span>
|
||
<span class="sd"> accessible through the attribute 'trt_tensor' and the latter using the</span>
|
||
<span class="sd"> attribute 'producer'.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">trt_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">tensor</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
|
||
<span class="n">is_network_input</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
<span class="n">tensor</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="o">=</span> <span class="n">trt_tensor</span>
|
||
<span class="n">tensor</span><span class="o">.</span><span class="n">producer</span> <span class="o">=</span> <span class="n">producer</span>
|
||
|
||
<span class="c1"># Set the layer name since this is the only</span>
|
||
<span class="c1"># centralized location to pass the name from</span>
|
||
<span class="c1"># module space to the TRT IR</span>
|
||
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_set_layer_name</span><span class="p">(</span><span class="n">producer</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">GATHER</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONCATENATION</span>
|
||
<span class="p">]:</span>
|
||
<span class="n">producer</span><span class="o">.</span><span class="n">precision</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">dtype</span>
|
||
<span class="k">assert</span> <span class="n">tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="k">if</span> <span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">cur_flayer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">cur_flayer</span><span class="o">.</span><span class="n">layer_name</span> <span class="o">=</span> <span class="n">producer</span><span class="o">.</span><span class="n">name</span>
|
||
|
||
<span class="k">return</span> <span class="n">tensor</span>
|
||
|
||
|
||
<div class="viewcode-block" id="RotaryScalingType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RotaryScalingType">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">RotaryScalingType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">none</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">linear</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">dynamic</span> <span class="o">=</span> <span class="mi">2</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">PositionEmbeddingType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">learned_absolute</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">rope_gptj</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">rope_gpt_neox</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="n">alibi</span> <span class="o">=</span> <span class="mi">3</span>
|
||
<span class="n">alibi_with_scale</span> <span class="o">=</span> <span class="mi">4</span>
|
||
<span class="n">relative</span> <span class="o">=</span> <span class="mi">5</span>
|
||
<span class="n">chatglm</span> <span class="o">=</span> <span class="mi">6</span>
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType.is_rope">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_rope">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">is_rope</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">rope_gptj</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">]</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType.is_alibi">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_alibi">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">is_alibi</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">alibi</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">alibi_with_scale</span><span class="p">]</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType.choices">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.choices">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">choices</span><span class="p">()</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span>
|
||
<span class="k">return</span> <span class="p">[</span><span class="n">embedding</span><span class="o">.</span><span class="n">name</span> <span class="k">for</span> <span class="n">embedding</span> <span class="ow">in</span> <span class="n">PositionEmbeddingType</span><span class="p">]</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="AttentionMaskType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AttentionMaskType">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">AttentionMaskType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">padding</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">causal</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">bidirectional</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="n">bidirectionalglm</span> <span class="o">=</span> <span class="mi">3</span> <span class="c1"># TODO: merge this mask into bidirectional</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="LayerNormType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.LayerNormType">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">LayerNormType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">LayerNorm</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">RmsNorm</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">GroupNorm</span> <span class="o">=</span> <span class="mi">2</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="LayerNormPositionType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.LayerNormPositionType">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">LayerNormPositionType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">pre_layernorm</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">post_layernorm</span> <span class="o">=</span> <span class="mi">1</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="MLPType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.MLPType">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">MLPType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">MLP</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">GatedMLP</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">FusedGatedMLP</span> <span class="o">=</span> <span class="mi">2</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="activation">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.activation">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">activation</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">act_type</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an activation function.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> act_type : trt.ActivationType</span>
|
||
<span class="sd"> The type of the activation (RELU, TANH, SIGMOID, ...).</span>
|
||
|
||
<span class="sd"> The following closures are defined in functional.*:</span>
|
||
|
||
<span class="sd"> relu for op=trt.ActivationType.RELU</span>
|
||
<span class="sd"> tanh for op=trt.ActivationType.TANH</span>
|
||
<span class="sd"> sigmoid for op=trt.ActivationType.SIGMOID</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">act_type</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="clip">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.clip">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">clip</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a CLIP operation that sets the range to [alpha, beta].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> alpha : float</span>
|
||
<span class="sd"> The lower bound of the CLIP function.</span>
|
||
|
||
<span class="sd"> beta : float</span>
|
||
<span class="sd"> The upper bound of the CLIP function.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">CLIP</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">beta</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="n">relu</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">RELU</span><span class="p">)</span>
|
||
<span class="n">tanh</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">TANH</span><span class="p">)</span>
|
||
<span class="n">sigmoid</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">SIGMOID</span><span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="silu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.silu">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">silu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a SiLU (`x * sigmoid(x)`) operation.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">input</span> <span class="o">*</span> <span class="n">sigmoid</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="swiglu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.swiglu">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">swiglu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a SwiGLU (`x * SiLU(gate)`) operation.</span>
|
||
|
||
<span class="sd"> That function takes a tensor, splits it into two halves along the last</span>
|
||
<span class="sd"> dimension, applies SiLU to the second half and multiply the results. The</span>
|
||
<span class="sd"> behaviour is undefined if the last dimension is not even.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">x</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="n">chunk</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">silu</span><span class="p">(</span><span class="n">gate</span><span class="p">)</span> <span class="o">*</span> <span class="n">x</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="squared_relu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.squared_relu">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">squared_relu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a Squared ReLU operation.</span>
|
||
|
||
<span class="sd"> This function applies ReLU and squares the output.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">pow</span><span class="p">(</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="mf">2.0</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="cast">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.cast">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">cast</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a cast operation.</span>
|
||
|
||
<span class="sd"> For an input tensor of type INT8, this function sets the dynamic range of</span>
|
||
<span class="sd"> the input to [-127, 127] for automatic dequantization. For a cast into</span>
|
||
<span class="sd"> INT8, that function sets the dynamic range of the output to [-127, 127] for</span>
|
||
<span class="sd"> automatic quantization.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the cast is applied.</span>
|
||
|
||
<span class="sd"> dtype : str or trt.DataType</span>
|
||
<span class="sd"> The data type of the output tensor after the cast. When 'dtype' is</span>
|
||
<span class="sd"> provided as a string, it must be a name amongst the valid names.</span>
|
||
<span class="sd"> See _str_to_trt_dtype_dict in _utils.py for a list of supported</span>
|
||
<span class="sd"> types and type names.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the inserted layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
|
||
<span class="n">cvt_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">):</span>
|
||
<span class="n">cvt_dtype</span> <span class="o">=</span> <span class="n">dtype</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"</span><span class="si">%s</span><span class="s2"> is not supported"</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dtype</span><span class="p">))</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_cast</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">cvt_dtype</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_output_type</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">cvt_dtype</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="s1">'int8'</span><span class="p">):</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">cvt_dtype</span> <span class="o">==</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="s1">'int8'</span><span class="p">):</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">output</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="flip">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.flip">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">flip</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Reverses the order of an n-D tensor along given axis in dims.</span>
|
||
|
||
<span class="sd"> That flip operation maps to a TensorRT ISliceLayer. For the dimensions</span>
|
||
<span class="sd"> listed in dims it copies the elements from the last one to the first one</span>
|
||
<span class="sd"> (from (N-1) down to 0 with a step of -1). For the dimensions not in 'dims',</span>
|
||
<span class="sd"> it copies the elements from the first one to the last one (from 0 to N-1</span>
|
||
<span class="sd"> with a step of 1).</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor whose elements are reversed.</span>
|
||
|
||
<span class="sd"> dims : list or tuple</span>
|
||
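<span class="sd"> The axes to flip. Negative indices are supported; they are normalized</span>
|
||
<span class="sd"> in place (the caller's dims is modified), so dims must support item</span>
|
||
<span class="sd"> assignment (e.g. a list, not a tuple) when it contains negative values.</span>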
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the inserted layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
|
||
<span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dims</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="o">-</span><span class="n">ndim</span> <span class="o"><=</span> <span class="n">value</span> <span class="o"><</span> <span class="n">ndim</span>
|
||
<span class="k">if</span> <span class="o">-</span><span class="n">ndim</span> <span class="o"><=</span> <span class="n">value</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">dims</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">+=</span> <span class="n">ndim</span>
|
||
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">dims</span><span class="p">))</span>
|
||
|
||
<span class="n">start_values</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dims</span> <span class="k">else</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
<span class="n">stride_values</span> <span class="o">=</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dims</span> <span class="k">else</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">start</span><span class="o">=</span><span class="n">start_values</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="o">=</span><span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(),</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="n">stride_values</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="interpolate">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.interpolate">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">interpolate</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">scale_factor</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'nearest'</span><span class="p">,</span>
|
||
<span class="n">align_corners</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">recompute_scale_factor</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">antialias</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document this function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
|
||
|
||
<span class="n">input_ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
|
||
<span class="k">assert</span> <span class="mi">2</span> <span class="o"><</span> <span class="n">input_ndim</span> <span class="o"><</span> <span class="mi">6</span><span class="p">,</span> <span class="s2">"Only 3D, 4D and 5D input Tensors supported"</span>
|
||
<span class="k">assert</span> <span class="p">(</span><span class="n">size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">^</span> <span class="p">(</span>
|
||
<span class="n">scale_factor</span>
|
||
<span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">),</span> <span class="s2">"Only one of out_shape or scales should be defined"</span>
|
||
|
||
<span class="k">assert</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">(</span><span class="s1">'nearest'</span><span class="p">,</span> <span class="s1">'linear'</span><span class="p">,</span> <span class="s1">'bilinear'</span><span class="p">,</span> <span class="s1">'bicubic'</span><span class="p">,</span> <span class="s1">'trilinear'</span><span class="p">,</span>
|
||
<span class="s1">'nearest-exact'</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">'trilinear'</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">5</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"trilinear only supports 5D tensor"</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">"bilinear"</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">4</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"bilinear only supports 4D tensor"</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">"linear"</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"linear only supports 3D tensor"</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_resize</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">input_shape</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
|
||
|
||
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">if</span> <span class="n">scale_factor</span><span class="p">:</span>
|
||
<span class="n">scale_len</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">,</span>
|
||
<span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="nb">int</span><span class="p">))</span> <span class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">scale_len</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">,</span> <span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="nb">int</span><span class="p">)):</span>
|
||
<span class="n">updated_scale</span> <span class="o">=</span> <span class="p">[</span><span class="n">scale_factor</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)]</span>
|
||
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">updated_scale</span> <span class="o">=</span> <span class="n">scale_factor</span>
|
||
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="nb">int</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">updated_scale</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span>
|
||
<span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span> <span class="k">if</span> <span class="n">i</span> <span class="o">></span> <span class="mi">1</span> <span class="k">else</span> <span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">size_len</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">size</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">size_len</span> <span class="o">==</span> <span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span>
|
||
<span class="k">if</span> <span class="n">size_len</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">updated_size</span> <span class="o">=</span> <span class="p">[</span><span class="n">size</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">updated_size</span> <span class="o">=</span> <span class="n">size</span>
|
||
|
||
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">if</span> <span class="n">i</span> <span class="o"><</span> <span class="mi">2</span> <span class="k">else</span> <span class="n">updated_size</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">2</span><span class="p">]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">updated_shape</span>
|
||
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'nearest'</span><span class="p">,</span> <span class="s1">'nearest-exact'</span><span class="p">]</span> <span class="ow">or</span> <span class="n">mode</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeMode</span><span class="o">.</span><span class="n">NEAREST</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ASYMMETRIC</span>
|
||
<span class="k">elif</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'linear'</span><span class="p">,</span> <span class="s1">'bilinear'</span><span class="p">,</span> <span class="s1">'trilinear'</span><span class="p">]:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeMode</span><span class="o">.</span><span class="n">LINEAR</span>
|
||
<span class="k">if</span> <span class="n">align_corners</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ALIGN_CORNERS</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
|
||
<span class="c1"># TODO, need to confirm the align_corners effect on bilinear mode.</span>
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">'bilinear'</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
|
||
|
||
<span class="k">elif</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'bicubic'</span><span class="p">]:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeMode</span><span class="o">.</span><span class="n">CUBIC</span>
|
||
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
|
||
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeMode</span><span class="o">.</span><span class="n">NEAREST</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ASYMMETRIC</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="matmul">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.matmul">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">matmul</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">mat2</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">transa</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">transb</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a matrix multiplication.</span>
|
||
|
||
<span class="sd"> That operation maps to a tensorrt.IMatrixMultiplyLayer layer. As explained</span>
|
||
<span class="sd"> in the TensorRT documentation, it computes the inner product between the</span>
|
||
<span class="sd"> two inputs after applying an optional transposition on the inputs.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The first tensor (often called A).</span>
|
||
|
||
<span class="sd"> mat2 : Tensor</span>
|
||
<span class="sd"> The second tensor (often called B).</span>
|
||
|
||
<span class="sd"> transa : bool</span>
|
||
<span class="sd"> Is the first input transposed? Set to 'True' if you want the first</span>
|
||
<span class="sd"> input to be transposed, 'False' otherwise.</span>
|
||
|
||
<span class="sd"> transb : bool</span>
|
||
<span class="sd"> Is the second input transposed? Set to 'True' if you want the</span>
|
||
<span class="sd"> second input to be transposed, 'False' otherwise.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the inserted layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="nb">input</span><span class="p">,</span> <span class="n">mat2</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">mat2</span><span class="p">)</span>
|
||
<span class="n">op0</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">TRANSPOSE</span> <span class="k">if</span> <span class="n">transa</span> \
|
||
<span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">NONE</span>
|
||
<span class="n">op1</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">TRANSPOSE</span> <span class="k">if</span> <span class="n">transb</span> \
|
||
<span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">NONE</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_matrix_multiply</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op0</span><span class="p">,</span>
|
||
<span class="n">mat2</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op1</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="constant">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.constant">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">constant</span><span class="p">(</span><span class="n">ndarray</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a constant layer.</span>
|
||
|
||
<span class="sd"> TensorRT graphs encapsulate constant values in the form of constant layers</span>
|
||
<span class="sd"> (tensorrt.IConstantLayer). This function creates such a layer from a NumPy</span>
|
||
<span class="sd"> array of values. After compilation of the network by TensorRT, those</span>
|
||
<span class="sd"> weights are stored in the serialized TensorRT engine.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> ndarray : numpy.ndarray</span>
|
||
<span class="sd"> The array of values (weights) encapsulated by this constant layer.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the inserted layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">weights</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">(</span><span class="n">np_dtype_to_trt</span><span class="p">(</span><span class="n">ndarray</span><span class="o">.</span><span class="n">dtype</span><span class="p">),</span> <span class="n">ndarray</span><span class="o">.</span><span class="n">ctypes</span><span class="o">.</span><span class="n">data</span><span class="p">,</span>
|
||
<span class="n">ndarray</span><span class="o">.</span><span class="n">size</span><span class="p">)</span>
|
||
<span class="c1"># Prevent underlying numpy array from going out of scope</span>
|
||
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">register_ndarray</span><span class="p">(</span><span class="n">ndarray</span><span class="p">)</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_constant</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">Dims</span><span class="p">(</span><span class="n">ndarray</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="n">weights</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_strongly_typed</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_output_type</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">np_dtype_to_trt</span><span class="p">(</span><span class="n">ndarray</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="c1"># TODO: TensorRT uses sizes of the output dimensions.</span>
|
||
<span class="c1"># DL frameworks usually use ends. Will change it to ends.</span>
|
||
<div class="viewcode-block" id="slice">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.slice">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">slice</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">starts</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">sizes</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">strides</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to extract a slice from a tensor.</span>
|
||
|
||
<span class="sd"> As described in the TensorRT documentation of the ISliceLayer, the slice</span>
|
||
<span class="sd"> layer has two variants: Static and dynamic.</span>
|
||
|
||
<span class="sd"> For static slicing, this function takes the starts and sizes values in the</span>
|
||
<span class="sd"> different dimensions to slice at layer creation time via a sequence of</span>
|
||
<span class="sd"> integers. For dynamic slicing, it accepts starts and sizes as</span>
|
||
<span class="sd"> Tensor objects.</span>
|
||
|
||
<span class="sd"> The slice layer selects for each dimension a start location from within the</span>
|
||
<span class="sd"> input tensor, and copies elements to the output tensor using the given</span>
|
||
<span class="sd"> strides across the input tensor (a stride of 1 in every dimension when</span>
|
||
<span class="sd"> strides is None). Start and size tensors must be 1-D int32 shape</span>
|
||
<span class="sd"> tensors if not specified as a sequence of integers.</span>
|
||
|
||
<span class="sd"> As an example, on input = [[0, 2, 4], [1, 3, 5]], the call to</span>
|
||
|
||
<span class="sd"> slice(input, starts=[1, 0], sizes=[1, 2])</span>
|
||
|
||
<span class="sd"> will produce the tensor [[1, 3]] as output. The slice operator when</span>
|
||
<span class="sd"> executed by TensorRT will copy one row (because sizes[0] == 1) starting from</span>
|
||
<span class="sd"> the 2nd row (because starts[0] == 1) and two columns (sizes[1] == 2) starting</span>
|
||
<span class="sd"> from the 1st column (because starts[1] == 0).</span>
|
||
|
||
<span class="sd"> In pseudo-code the behaviour of that operation can be described as follows</span>
|
||
<span class="sd"> for a 2D tensor (and easily be extended to more dimensions):</span>
|
||
|
||
<span class="sd"> output = Tensor(shape=sizes)</span>
|
||
<span class="sd"> for ii in range(sizes[0]):</span>
|
||
<span class="sd"> for jj in range(sizes[1]):</span>
|
||
<span class="sd"> output[ii][jj] = input[starts[0]+ii][starts[1]+jj]</span>
|
||
|
||
<span class="sd"> Note that it is common in deep-learning frameworks to use ranges</span>
|
||
<span class="sd"> [start:end] for similar operations. It can be emulated by setting the sizes</span>
|
||
<span class="sd"> argument such that in each dimension [start:start+size] == [start:end] i.e.</span>
|
||
<span class="sd"> size = end-start.</span>
|
||
|
||
<span class="sd"> TensorRT supports different slice modes, but this function restricts that</span>
|
||
<span class="sd"> choice to `mode == tensorrt.SliceMode.STRICT_BOUNDS`.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the slicing is performed.</span>
|
||
|
||
<span class="sd"> starts : Union[Tensor, Sequence[int]]</span>
|
||
<span class="sd"> The starting points in the input tensor, one per dimension.</span>
|
||
|
||
<span class="sd"> sizes : Union[Tensor, Sequence[int]]</span>
|
||
<span class="sd"> The number of elements in each dimension of the sliced tensor (output).</span>
|
||
|
||
<span class="sd"> strides : Union[Tensor, Sequence[int]]</span>
|
||
<span class="sd"> The step taken from start in each dimension of the input tensor.</span>
|
||
<span class="sd"> Defaults to 1 in every dimension when None.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the slice layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">input_ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
|
||
<span class="n">trt_starts</span> <span class="o">=</span> <span class="n">starts</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">starts</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="n">trt_starts</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)]</span> <span class="c1"># unused dummy value</span>
|
||
|
||
<span class="n">trt_sizes</span> <span class="o">=</span> <span class="n">sizes</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="n">trt_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)]</span> <span class="c1"># unused dummy value</span>
|
||
|
||
<span class="n">trt_strides</span> <span class="o">=</span> <span class="n">strides</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">strides</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">or</span> <span class="n">strides</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">trt_strides</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">start</span><span class="o">=</span><span class="n">trt_starts</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="o">=</span><span class="n">trt_sizes</span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="n">trt_strides</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">starts</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">starts</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">sizes</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">strides</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">strides</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="c1"># TODO: support step.</span>
|
||
<div class="viewcode-block" id="arange">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.arange">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">arange</span><span class="p">(</span><span class="n">start</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span> <span class="n">end</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to fill a 1D tensor.</span>
|
||
|
||
<span class="sd"> The tensor is filled with the values from start (inclusive) to end (exclusive) with a step of 1</span>
|
||
<span class="sd"> between the different elements. In pseudo-code, it corresponds to a tensor</span>
|
||
<span class="sd"> populated with the values:</span>
|
||
|
||
<span class="sd"> output = Tensor([dtype(ii) for ii in range(start, end, 1)])</span>
|
||
|
||
<span class="sd"> For example, a call to arange(3, 6, 'int32') will add an operation to the</span>
|
||
<span class="sd"> TensorRT graph that will produce [3, 4, 5] when executed. The call to</span>
|
||
<span class="sd"> arange(2, 5, 'float32') will add a layer to generate [2.0, 3.0, 4.0].</span>
|
||
|
||
<span class="sd"> This operation is implemented using a tensorrt.IFillLayer in</span>
|
||
<span class="sd"> trt.FillOperation.LINSPACE mode.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> start : Union[Tensor, int]</span>
|
||
<span class="sd"> The starting point of the range.</span>
|
||
|
||
<span class="sd"> end : Union[Tensor, int]</span>
|
||
<span class="sd"> The end point of the range (exclusive). It must be of the same kind (int or Tensor) as 'start'.</span>
|
||
|
||
<span class="sd"> dtype : str</span>
|
||
<span class="sd"> The type of the elements. See _str_to_trt_dtype_dict in _utils.py</span>
|
||
<span class="sd"> for a list of supported types and type names.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the fill layer. It is a 1D tensor containing</span>
|
||
<span class="sd"> `end-start` elements of type `dtype`.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span>
|
||
<span class="n">start</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="n">start</span><span class="p">))</span>
|
||
<span class="n">end</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="n">end</span><span class="p">))</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"</span><span class="si">%s</span><span class="s2"> is not supported"</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">start</span><span class="p">))</span>
|
||
|
||
<span class="n">step</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="mi">1</span><span class="p">]))</span>
|
||
|
||
<span class="n">num</span> <span class="o">=</span> <span class="n">end</span> <span class="o">-</span> <span class="n">start</span>
|
||
<span class="n">num</span> <span class="o">=</span> <span class="n">num</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_fill</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">trt</span><span class="o">.</span><span class="n">FillOperation</span><span class="o">.</span><span class="n">LINSPACE</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 1</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">start</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 0</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">step</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 1</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="expand">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">expand</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">expand_shape</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to expand a tensor.</span>
|
||
|
||
<span class="sd"> The operation expands the input tensor in the singleton dimensions to the</span>
|
||
<span class="sd"> size indicated by the corresponding dimension in the `expand_shape` tensor.</span>
|
||
<span class="sd"> In other words, given an input tensor with dimensions of size 1, those</span>
|
||
<span class="sd"> dimensions will be expanded to the size in `expand_shape`.</span>
|
||
|
||
<span class="sd"> For example, a tensor of shape [4, 3, 1, 3] will be expanded to a tensor of</span>
|
||
<span class="sd"> shape [4, 3, 2, 3] by the layer created using expand(input, [4, 3, 2, 3]).</span>
|
||
|
||
<span class="sd"> The expansion may either replicate the values or be mapped to a view with a</span>
|
||
<span class="sd"> stride of 0 in the expanded dimensions. For example, for a tensor [[3, 2]] of</span>
|
||
<span class="sd"> shape [1, 2],</span>
|
||
|
||
<span class="sd"> expand([[3, 2]], [2, 2])</span>
|
||
|
||
<span class="sd"> can be used to expand the input to [[3, 2], [3, 2]].</span>
|
||
|
||
<span class="sd"> This operation is implemented using a tensorrt.ISliceLayer. The current</span>
|
||
<span class="sd"> implementation does not verify that non-singleton dimensions are not</span>
|
||
<span class="sd"> shrunk. In other words, for an input of shape [4, 1, 2],</span>
|
||
|
||
<span class="sd"> expand(input, [3, 2, 2])</span>
|
||
|
||
<span class="sd"> will produce a tensor of shape [3, 2, 2]. That behaviour is subject to</span>
|
||
<span class="sd"> change in the future.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> expand_shape : Tensor</span>
|
||
<span class="sd"> The new shape of the expanded tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the expand layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">start</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)],</span>
|
||
<span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)],</span> <span class="c1"># unused dummy value</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span> <span class="c1"># unused dummy value</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c1"># The stride is either:</span>
|
||
<span class="c1"># 0 for dimensions of size 1 (i.e. shape(input, i) - 1 == 1 - 1 == 0) or,</span>
|
||
<span class="c1"># 1 for dimensions of size > 1 since minimum(value >= 1, 1) == 1.</span>
|
||
<span class="n">stride_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span>
|
||
<span class="p">[</span><span class="n">minimum</span><span class="p">((</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)])</span>
|
||
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">expand_shape</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride_tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="einsum">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.einsum">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">einsum</span><span class="p">(</span><span class="n">einsum_eq</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">inputs</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an Einsum operation.</span>
|
||
|
||
<span class="sd"> That operation maps to tensorrt.IEinsumLayer. As explained in the TensorRT</span>
|
||
<span class="sd"> documentation, this layer implements a summation over the elements of the</span>
|
||
<span class="sd"> inputs along dimensions specified by the equation parameter, based on the</span>
|
||
<span class="sd"> Einstein summation convention. The layer can have one or more inputs of</span>
|
||
<span class="sd"> rank >= 0. All the inputs must be of same data type. This layer supports</span>
|
||
<span class="sd"> all TensorRT data types except bool. There is one output tensor of the same</span>
|
||
<span class="sd"> type as the input tensors. The shape of output tensor is determined by the</span>
|
||
<span class="sd"> equation.</span>
|
||
|
||
<span class="sd"> The equation specifies ASCII lower-case letters for each dimension in the</span>
|
||
<span class="sd"> inputs in the same order as the dimensions, separated by comma for each</span>
|
||
<span class="sd"> input. The dimensions labeled with the same subscript must match or be</span>
|
||
<span class="sd"> broadcastable. Repeated subscript labels in one input take the diagonal.</span>
|
||
<span class="sd"> Repeating a label across multiple inputs means that those axes will be</span>
|
||
<span class="sd"> multiplied. Omitting a label from the output means values along those axes</span>
|
||
<span class="sd"> will be summed. In implicit mode, the indices which appear once in the</span>
|
||
<span class="sd"> expression will be part of the output in increasing alphabetical order. In</span>
|
||
<span class="sd"> explicit mode, the output can be controlled by specifying output subscript</span>
|
||
<span class="sd"> labels by adding an arrow (‘->’) followed by subscripts for the output. For</span>
|
||
<span class="sd"> example, “ij,jk->ik” is equivalent to “ij,jk”. Ellipsis (‘…’) can be used</span>
|
||
<span class="sd"> in place of subscripts to broadcast the dimensions. See the TensorRT</span>
|
||
<span class="sd"> Developer Guide for more details on equation syntax.</span>
|
||
|
||
<span class="sd"> Many common operations can be expressed using the Einsum equation. For</span>
|
||
<span class="sd"> example:</span>
|
||
<span class="sd"> Matrix Transpose: ij->ji</span>
|
||
<span class="sd"> Sum: ij-></span>
|
||
<span class="sd"> Matrix-Matrix Multiplication: ik,kj->ij</span>
|
||
<span class="sd"> Dot Product: i,i-></span>
|
||
<span class="sd"> Matrix-Vector Multiplication: ik,k->i</span>
|
||
<span class="sd"> Batch Matrix Multiplication: ijk,ikl->ijl</span>
|
||
<span class="sd"> Batch Diagonal: …ii->…i</span>
|
||
|
||
<span class="sd"> Note that TensorRT does not support ellipsis or diagonal operations so</span>
|
||
<span class="sd"> neither does TensorRT-LLM.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> einsum_eq : str</span>
|
||
<span class="sd"> The Einsum equation.</span>
|
||
|
||
<span class="sd"> inputs: Sequence[Tensor]</span>
|
||
<span class="sd"> The sequence of inputs consumed by the Einsum operation.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the Einsum operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_einsum</span><span class="p">([</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">],</span>
|
||
<span class="n">einsum_eq</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="permute">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.permute">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">permute</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to permute the dimensions of a tensor.</span>
|
||
|
||
<span class="sd"> The dimensions of the input tensor are permuted according to the sequence</span>
|
||
<span class="sd"> of dimensions in 'dims'. That operation maps to tensorrt.IShuffleLayer where</span>
|
||
<span class="sd"> the second transposition is described by the indices in 'dims'.</span>
|
||
|
||
<span class="sd"> Given a tensor of rank N, the result of the permutation is a tensor of rank</span>
|
||
<span class="sd"> N in which the i-th dimension is the dims[i]-th dimension of the input.</span>
|
||
|
||
<span class="sd"> For example, permute(input, [1, 0]) will transpose a 2D tensor by permuting</span>
|
||
<span class="sd"> the rows and columns.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to permute.</span>
|
||
|
||
<span class="sd"> dims : Sequence[int]</span>
|
||
<span class="sd"> The description of the permutation. Negative values are resolved relative to the rank of 'input'.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the permutation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">dims</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">dims</span><span class="p">),</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">second_transpose</span> <span class="o">=</span> <span class="n">dims</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="transpose">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.transpose">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">transpose</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim0</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim1</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to transpose two dimensions of a tensor.</span>
|
||
|
||
<span class="sd"> That operation produces a tensor in which the dimensions 'dim0' and 'dim1'</span>
|
||
<span class="sd"> are permuted. The other dimensions, if the rank of the tensor is greater</span>
|
||
<span class="sd"> than 2, remain untouched.</span>
|
||
|
||
<span class="sd"> That function is a helper built on the 'functional.permute' function.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to transpose.</span>
|
||
|
||
<span class="sd"> dim0 : int</span>
|
||
<span class="sd"> The first dimension to transpose.</span>
|
||
|
||
<span class="sd"> dim1 : int</span>
|
||
<span class="sd"> The second dimension to transpose.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the permutation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">permutation</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()))</span>
|
||
<span class="n">permutation</span><span class="p">[</span><span class="n">dim0</span><span class="p">]</span> <span class="o">=</span> <span class="n">dim1</span>
|
||
<span class="n">permutation</span><span class="p">[</span><span class="n">dim1</span><span class="p">]</span> <span class="o">=</span> <span class="n">dim0</span>
|
||
|
||
<span class="k">return</span> <span class="n">permute</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">permutation</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="view">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.view">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">view</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">zero_is_placeholder</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to create a view of a tensor.</span>
|
||
|
||
<span class="sd"> That operation adds a tensorrt.IShuffleLayer to the network. If the 'shape'</span>
|
||
<span class="sd"> parameter is a Tensor, that view is dynamic. Otherwise, it is a static</span>
|
||
<span class="sd"> view.</span>
|
||
|
||
<span class="sd"> Note that TensorRT limits the number of inferred dimensions to 1. It means</span>
|
||
<span class="sd"> that the shape sequence or tensor cannot contain more than one -1. When</span>
|
||
<span class="sd"> 'shape' is a sequence, this function asserts if that constraint is not respected. The values of a shape tensor are only known at runtime and are not checked here.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to view (reshape).</span>
|
||
|
||
<span class="sd"> shape : Union[Tensor, Sequence[int]]</span>
|
||
<span class="sd"> The shape of the new tensor.</span>
|
||
|
||
<span class="sd"> zero_is_placeholder : bool</span>
|
||
<span class="sd"> When that parameter is True, the 0s in 'shape' are replaced by the</span>
|
||
<span class="sd"> sizes of the corresponding dimensions from the 'input'. Otherwise,</span>
|
||
<span class="sd"> the dimensions corresponding to 0s are shrunk.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the view/shuffle layer.</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="c1"># TensorRT demands that at most one dimension is permitted to be specified as -1</span>
|
||
<span class="k">def</span> <span class="nf">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="nb">list</span><span class="p">):</span>
|
||
<span class="n">inferred_dim_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">list</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">inferred_dim_list</span><span class="p">)</span> <span class="o"><=</span> <span class="mi">1</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">zero_is_placeholder</span> <span class="o">=</span> <span class="n">zero_is_placeholder</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="n">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="n">shape</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">shape</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
|
||
<span class="n">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">reshape_dims</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"</span><span class="si">%s</span><span class="s2"> is not supported"</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">shape</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="expand_dims">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_dims">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to expand the tensor shape with singleton dimensions.</span>
|
||
|
||
<span class="sd"> That function adds a tensorrt.IShuffleLayer to the network. Given an 'input'</span>
|
||
<span class="sd"> of rank N and a sequence of M dimensions, the output tensor produced by</span>
|
||
<span class="sd"> this operation (when executed by TensorRT) will have a rank of N+M. Singleton</span>
|
||
<span class="sd"> dimensions will be inserted at the different positions in 'dim'.</span>
|
||
|
||
<span class="sd"> The pseudo-code for that operation is:</span>
|
||
|
||
<span class="sd"> new_shape, ii = [], 0</span>
|
||
<span class="sd"> for jj in range(input.rank() + len(dim)):</span>
|
||
<span class="sd"> new_shape.append(1 if jj in dim else input.shape[ii++])</span>
|
||
|
||
<span class="sd"> For example, for a tensor of shape [3, 4, 1, 5]</span>
|
||
|
||
<span class="sd"> expand_dims(input, [0, 2])</span>
|
||
|
||
<span class="sd"> will produce a tensor of shape [1, 3, 1, 4, 1, 5].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to expand.</span>
|
||
|
||
<span class="sd"> dim : Union[int, Sequence[int]]</span>
|
||
<span class="sd"> The positions in the output tensor where to insert singleton</span>
|
||
<span class="sd"> dimensions.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the shuffle layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="p">)</span>
|
||
|
||
<span class="n">out_ndim</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="o">+</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
|
||
<span class="n">input_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
|
||
<span class="n">out_shapes</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="n">j</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">out_ndim</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dim</span><span class="p">:</span>
|
||
<span class="n">out_shapes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">out_shapes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">gather</span><span class="p">(</span><span class="n">input_shape</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">j</span><span class="p">))</span>
|
||
<span class="n">j</span> <span class="o">=</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">1</span>
|
||
|
||
<span class="n">out_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">out_shapes</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">out_shape</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="unsqueeze">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.unsqueeze">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">unsqueeze</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">axis</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to insert a singleton dimension to a tensor.</span>
|
||
|
||
<span class="sd"> That function creates an operation that inserts a singleton dimension</span>
|
||
<span class="sd"> (dimension of size 1) at position 'axis' in the output tensor. It works with</span>
|
||
<span class="sd"> negative values for the 'axis'.</span>
|
||
|
||
<span class="sd"> For example, for a tensor 'input' of shape [4, 4]:</span>
|
||
|
||
<span class="sd"> unsqueeze(input, 0) will produce an output of shape [1, 4, 4],</span>
|
||
<span class="sd"> unsqueeze(input, 1) will produce an output of shape [4, 1, 4],</span>
|
||
<span class="sd"> unsqueeze(input, -1) will produce an output of shape [4, 4, 1],</span>
|
||
<span class="sd"> unsqueeze(input, -2) will produce an output of shape [4, 1, 4].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to expand with a singleton dimension.</span>
|
||
|
||
<span class="sd"> axis : int</span>
|
||
<span class="sd"> The index of the singleton dimension in the output tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">axis</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">axis</span> <span class="o">=</span> <span class="n">axis</span> <span class="o">+</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
|
||
|
||
<span class="k">return</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">axis</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="expand_dims_like">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_dims_like">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">right</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to expand the first tensor to the same rank as the second</span>
|
||
<span class="sd"> tensor.</span>
|
||
|
||
<span class="sd"> That function takes a first tensor. It also accepts an integer or a float,</span>
|
||
<span class="sd"> in which case it creates a constant tensor from it. In both cases, the rank</span>
|
||
<span class="sd"> of that first tensor is compared to the rank of the second tensor. If they</span>
|
||
<span class="sd"> are of the same rank, the first tensor is returned. Otherwise, the first</span>
|
||
<span class="sd"> tensor is expanded on the left to match the rank of the second tensor.</span>
|
||
|
||
<span class="sd"> Note that the shapes do not have to match, only the rank is considered in</span>
|
||
<span class="sd"> that function.</span>
|
||
|
||
<span class="sd"> For example, for a pair of tensors of shapes [3, 4] and [4, 3, 2], the</span>
|
||
<span class="sd"> first tensor will be expanded to a tensor of rank 3 and shape [1, 3, 4].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> left : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The first tensor to expand. When a scalar value is provided as a</span>
|
||
<span class="sd"> parameter, that function first creates a tensor before expanding it</span>
|
||
<span class="sd"> (if needed).</span>
|
||
|
||
<span class="sd"> right : Tensor</span>
|
||
<span class="sd"> The reference tensor to match.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the shuffle layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="nb">float</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">and</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp16_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
|
||
<span class="n">left_ndim</span> <span class="o">=</span> <span class="n">left</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="n">right_ndim</span> <span class="o">=</span> <span class="n">right</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">right_ndim</span> <span class="o">></span> <span class="n">left_ndim</span><span class="p">:</span>
|
||
<span class="n">new_ndim</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">right_ndim</span> <span class="o">-</span> <span class="n">left_ndim</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">new_ndim</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">left</span></div>
|
||
|
||
|
||
|
||
<span class="c1"># If dim is None, return a 1-D TensorRT-LLM tensor holding the input shape</span>
|
||
<span class="c1"># If dim is not None, return a 0-D TensorRT-LLM tensor of the dimension size</span>
|
||
<div class="viewcode-block" id="shape">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.shape">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to create a shape tensor.</span>
|
||
|
||
<span class="sd"> The shape tensor can either be the shape of the input tensor when the</span>
|
||
<span class="sd"> parameter dim is None or a scalar (tensor of rank 0) that corresponds to</span>
|
||
<span class="sd"> the size of the dim-th dimension.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor from which we want to extract the shape or the</span>
|
||
<span class="sd"> size in one dimension.</span>
|
||
|
||
<span class="sd"> dim : Optional[int]</span>
|
||
<span class="sd"> The dimension from which to extract the size. If it is None, the</span>
|
||
<span class="sd"> entire shape of the input tensor is returned.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor that contains the shape of the input tensor (if 'dim' is None)</span>
|
||
<span class="sd"> or the size in the dimension 'dim' of the input tensor. If 'dim' is</span>
|
||
<span class="sd"> 'None', that tensor is 1-D with one element per input dimension, otherwise</span>
|
||
<span class="sd"> its rank is 0.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shape</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">res</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">res</span>
|
||
|
||
<span class="k">return</span> <span class="n">gather</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">indices</span><span class="o">=</span><span class="n">dim</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">([])</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="gather">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gather">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">gather</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">indices</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to gather elements from a tensor.</span>
|
||
|
||
<span class="sd"> That function implements the GatherElements operator from the ONNX</span>
|
||
<span class="sd"> specification as described in</span>
|
||
|
||
<span class="sd"> https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements</span>
|
||
|
||
<span class="sd"> The input and indices arguments must have the same rank >= 1. The operation</span>
|
||
<span class="sd"> will produce a tensor with the same shape as the indices tensor. The 'dim'</span>
|
||
<span class="sd"> argument is the axis to gather on.</span>
|
||
|
||
<span class="sd"> As shown in the ONNX description, for a 3D tensor, the output is:</span>
|
||
|
||
<span class="sd"> out[i][j][k] = input[indices[i][j][k]][j][k] if axis = 0,</span>
|
||
<span class="sd"> out[i][j][k] = input[i][indices[i][j][k]][k] if axis = 1,</span>
|
||
<span class="sd"> out[i][j][k] = input[i][j][indices[i][j][k]] if axis = 2.</span>
|
||
|
||
<span class="sd"> For example,</span>
|
||
|
||
<span class="sd"> gather([[4, 2], [5, 3]], 0, [[1, 0], [0, 1]])</span>
|
||
|
||
<span class="sd"> will produce [[5, 2], [4, 3]].</span>
|
||
|
||
<span class="sd"> gather([[1, 2, 3], [4, 5, 6]], 1, [[1], [0]])</span>
|
||
|
||
<span class="sd"> will produce [[2], [4]]. See the ONNX documentation for more examples.</span>
|
||
|
||
<span class="sd"> That operation maps to the TensorRT IGatherLayer.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to gather elements from.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension to gather on.</span>
|
||
|
||
<span class="sd"> indices : Union[Tensor, int]</span>
|
||
<span class="sd"> The positions in the 'dim' dimension to gather from.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor containing the gathered elements. It has the same shape as</span>
|
||
<span class="sd"> the indices tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">indices</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">indices</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">indices</span><span class="p">]))</span>
|
||
|
||
<span class="c1"># The input and indices tensors must have the same rank.</span>
|
||
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="n">indices</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather_v2</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">indices</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">GatherMode</span><span class="o">.</span><span class="n">ELEMENT</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="n">dim</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">dim</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="select">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.select">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">select</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to select a slice of elements from a tensor.</span>
|
||
|
||
<span class="sd"> Given an input tensor, that function creates an operation that selects the</span>
|
||
<span class="sd"> index-th slice of elements in the dimension 'dim' to create a new tensor.</span>
|
||
<span class="sd"> The output tensor has a shape in which the input dimension 'dim' is</span>
|
||
<span class="sd"> removed.</span>
|
||
|
||
<span class="sd"> The 'index' can either be an integer or a 1D tensor containing a single</span>
|
||
<span class="sd"> element.</span>
|
||
|
||
<span class="sd"> For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
|
||
<span class="sd"> [3, 3],</span>
|
||
|
||
<span class="sd"> select(input, 0, 1)</span>
|
||
|
||
<span class="sd"> will create a tensor of shape [3] that contains [2, 1, 2].</span>
|
||
|
||
<span class="sd"> Regarding the shape of the output tensor, the dimension 'dim' is removed.</span>
|
||
<span class="sd"> It means that for a tensor of shape [4, 2, 6, 3],</span>
|
||
|
||
<span class="sd"> select(input, 2, 4)</span>
|
||
|
||
<span class="sd"> will select the 5th slice (index == 4) from the 3rd dimension (dim == 2)</span>
|
||
<span class="sd"> and return a tensor of shape [4, 2, 3] (i.e. the 3rd dimension is removed).</span>
|
||
|
||
<span class="sd"> That operation maps to the TensorRT IGatherLayer.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to select from.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension to select from.</span>
|
||
|
||
<span class="sd"> index : Union[Tensor, int]</span>
|
||
<span class="sd"> The index of the slice in the 'dim' dimension to select.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor containing the selected slice.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">index</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">index</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">index</span><span class="p">]))</span>
|
||
<span class="k">assert</span> <span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">index</span><span class="o">.</span><span class="n">size</span><span class="p">(</span>
|
||
<span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"index should have rank 1, got </span><span class="si">{</span><span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span>
|
||
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()):</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">:</span>
|
||
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">index</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">new_shape</span><span class="p">))</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="index_select">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.index_select">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">index_select</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to select slices of elements from a tensor.</span>
|
||
|
||
<span class="sd"> Given an input tensor, that function creates an operation that selects the</span>
|
||
<span class="sd"> slices of elements in the dimension 'dim' at the indices listed in 'index'</span>
|
||
<span class="sd"> to create a new tensor. The output tensor has the same rank as the input</span>
|
||
<span class="sd"> tensor.</span>
|
||
|
||
<span class="sd"> The 'index' is a tensor of rank 1.</span>
|
||
|
||
<span class="sd"> For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
|
||
<span class="sd"> [3, 3],</span>
|
||
|
||
<span class="sd"> index_select(input, 0, [0, 1])</span>
|
||
|
||
<span class="sd"> will create a tensor of shape [2, 3] that contains [[4, 2, 5], [2, 1, 2]].</span>
|
||
|
||
<span class="sd"> Regarding the shape of the output tensor, the dimension 'dim' has the same</span>
|
||
<span class="sd"> size as the 'index' tensor. It means that for an input tensor of shape [4, 2, 6, 3],</span>
|
||
|
||
<span class="sd"> index_select(input, 2, [1, 4])</span>
|
||
|
||
<span class="sd"> will select the 2nd and 5th slices (index == 1 or 4) from the 3rd dimension</span>
|
||
<span class="sd"> (dim == 2) and return a tensor of shape [4, 2, 2, 3] (i.e. the 3rd</span>
|
||
<span class="sd"> dimension is shrunk to 2).</span>
|
||
|
||
<span class="sd"> Note that this operation can also be used to expand a tensor in the 'dim'</span>
|
||
<span class="sd"> dimension, for example, on input [[0, 1], [2, 3]],</span>
|
||
|
||
<span class="sd"> index_select(input, 1, [0, 0, 0])</span>
|
||
|
||
<span class="sd"> will produce a tensor of shape [2, 3] containing [[0, 0, 0], [2, 2, 2]].</span>
|
||
|
||
<span class="sd"> That operation maps to the TensorRT IGatherLayer.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to select from.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension to select from.</span>
|
||
|
||
<span class="sd"> index : Tensor</span>
|
||
<span class="sd"> The indices of the slices in the 'dim' dimension to select.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor containing the selected slices.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"index should have rank 1, got </span><span class="si">{</span><span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span>
|
||
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()):</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">:</span>
|
||
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="n">index</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">index</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">new_shape</span><span class="p">))</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="concat">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.concat">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">concat</span><span class="p">(</span><span class="n">inputs</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">]],</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to concatenate tensors.</span>
|
||
|
||
<span class="sd"> The function creates an operation that concatenates the tensors from the</span>
|
||
<span class="sd"> sequence 'inputs'. The concatenation is done along the dimension 'dim'.</span>
|
||
|
||
<span class="sd"> All the tensors in 'inputs' must have the same shape except for the</span>
|
||
<span class="sd"> dimension 'dim'.</span>
|
||
|
||
<span class="sd"> for ii in range(inputs[0].rank()):</span>
|
||
<span class="sd"> assert (ii == dim) or all(inp.shape[ii] == inputs[0].shape[ii] for inp in inputs)</span>
|
||
|
||
<span class="sd"> The shape of the output tensor is defined as:</span>
|
||
|
||
<span class="sd"> for ii in range(inputs[0].rank()):</span>
|
||
<span class="sd"> # Same size as all the inputs in dimension ii != dim.</span>
|
||
<span class="sd"> output.shape[ii] = inputs[0].shape[ii]</span>
|
||
|
||
<span class="sd"> # Sum of the sizes in the different inputs in dimension 'dim'.</span>
|
||
<span class="sd"> if ii == dim:</span>
|
||
<span class="sd"> for jj in range(1, len(inputs)):</span>
|
||
<span class="sd"> output.shape[ii] += inputs[jj].shape[ii]</span>
|
||
|
||
<span class="sd"> For example, given a sequence of two 2D tensors [[0, 1], [2, 3]] and</span>
|
||
<span class="sd"> [[4, 5], [6, 7]] both of shape [2, 2],</span>
|
||
|
||
<span class="sd"> concat(inputs, 0)</span>
|
||
|
||
<span class="sd"> will produce [[0, 1], [2, 3], [4, 5], [6, 7]] of shape [4, 2] and</span>
|
||
|
||
<span class="sd"> concat(inputs, 1)</span>
|
||
|
||
<span class="sd"> will produce [[0, 1, 4, 5], [2, 3, 6, 7]] of shape [2, 4].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> inputs : Sequence[Union[Tensor, int]]</span>
|
||
<span class="sd"> The sequence of tensors to concatenate. For integers, the function</span>
|
||
<span class="sd"> creates constant tensors.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension in which the concatenation is performed.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor that contains the concatenation of the tensors.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">tmp</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">tmp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">i</span><span class="p">])))</span>
|
||
<span class="k">elif</span> <span class="n">i</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">tmp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">i</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="mi">1</span><span class="p">]))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">tmp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_concatenation</span><span class="p">([</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">tmp</span><span class="p">])</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">dim</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="softmax">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.softmax">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">softmax</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to compute softmax on a tensor.</span>
|
||
|
||
<span class="sd"> That operation computes the softmax on the input tensor in the dimension</span>
|
||
<span class="sd"> 'dim' if specified. Otherwise, it is applied on the last dimension.</span>
|
||
|
||
<span class="sd"> It inserts an ISoftmaxLayer into the TensorRT graph.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which to apply softmax.</span>
|
||
|
||
<span class="sd"> dim : Optional[int]</span>
|
||
<span class="sd"> The dimension used to apply softmax.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor of the softmax layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="n">dim</span>
|
||
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_softmax</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="k">def</span> <span class="nf">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to perform lookup in a tensor.</span>
|
||
|
||
<span class="sd"> That operation performs the lookup needed by embedding layers. Given a</span>
|
||
<span class="sd"> 'weight' tensor of shape [rows, cols], it produces a tensor of shape</span>
|
||
<span class="sd"> [input.size(0), cols] where the ith row corresponds to the input[i] row in</span>
|
||
<span class="sd"> the weight tensor.</span>
|
||
|
||
<span class="sd"> It inserts an IPluginV2Layer.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor that contains the indices used to perform the lookup.</span>
|
||
|
||
<span class="sd"> weight : Tensor</span>
|
||
<span class="sd"> The table to gather from.</span>
|
||
|
||
<span class="sd"> rank : int</span>
|
||
<span class="sd"> The mpi rank.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor of the lookup layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Lookup'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lookup_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">rank</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">pf_type</span><span class="p">,</span> <span class="n">rank</span><span class="p">])</span>
|
||
<span class="n">lookup_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"lookup"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">lookup_plug</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="embedding">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.embedding">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">embedding</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">tp_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">tp_group</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">sharding_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">tp_rank</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">workspace</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">instance_id</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to perform embedding lookup.</span>
|
||
|
||
<span class="sd"> That operation performs the embedding lookup. The 'input' tensor contains</span>
|
||
<span class="sd"> the identifiers of the rows of 'weight' to gather.</span>
|
||
|
||
<span class="sd"> 1. Distribute the embedding lookup table over multiple GPUs</span>
|
||
<span class="sd"> When 'tp_size' is greater than 1 and the 'tp_group' is defined, this</span>
|
||
<span class="sd"> embedding lookup is distributed among multiple GPUs.</span>
|
||
|
||
<span class="sd"> When 'sharding_dim==0', each GPU stores a subset of the rows of the embedding</span>
|
||
<span class="sd"> table (the number of rows per GPU is given by weight.shape[0] and the offset to</span>
|
||
<span class="sd"> the 1st row stored on the GPU is given by tp_rank * weight.shape[0]). Each</span>
|
||
<span class="sd"> parallel rank will query all the indices and set 0s for the weights that</span>
|
||
<span class="sd"> are not stored on the associated GPU. To compute the final result, a</span>
|
||
<span class="sd"> parallel all-reduce operation is added to the TensorRT graph. That lookup</span>
|
||
<span class="sd"> can be performed using either the plugin or the operators that TensorRT supports.</span>
|
||
|
||
<span class="sd"> When 'sharding_dim==1', each GPU stores a subset of the embedding table's columns.</span>
|
||
<span class="sd"> Each rank can obtain a portion of the embedding results.</span>
|
||
<span class="sd"> Then the embedding is collected using the all-gather operation.</span>
|
||
<span class="sd"> Related transposition operations are also used to obtain the final results.</span>
|
||
|
||
<span class="sd"> 2. Store embedding lookup table as a whole</span>
|
||
<span class="sd"> When 'tp_size' is not greater than 1 (or 'tp_group' is None), the embedding lookup table will not</span>
|
||
<span class="sd"> be divided. In this case, when the default_net().plugin_config.lookup_plugin is set,</span>
|
||
<span class="sd"> the operation is implemented using a plugin (without the all-reduce operation).</span>
|
||
<span class="sd"> Otherwise, this operation is implemented using the standard IGatherLayer in TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor that contains the indices to perform the lookup.</span>
|
||
|
||
<span class="sd"> weight : Tensor</span>
|
||
<span class="sd"> The table to gather from.</span>
|
||
|
||
<span class="sd"> tp_size : int</span>
|
||
<span class="sd"> The number of GPUs collaborating to perform that embedding.</span>
|
||
|
||
<span class="sd"> tp_group : Optional[List[int]]</span>
|
||
<span class="sd"> The group of world ranks participating in the all-reduce when</span>
|
||
<span class="sd"> tp_size > 1.</span>
|
||
|
||
<span class="sd"> sharding_dim : int</span>
|
||
<span class="sd"> sharding_dim = 0 means that we shard the embedding table in vocab dim;</span>
|
||
<span class="sd"> sharding_dim = 1 means that we shard the embedding table in embedding dim.</span>
|
||
|
||
<span class="sd"> tp_rank : int</span>
|
||
<span class="sd"> The tensor parallelism rank. Used to calculate offset in TP on vocab dim.</span>
|
||
|
||
<span class="sd"> workspace: Optional[Tensor]</span>
|
||
<span class="sd"> See allreduce's documentation for workspace.</span>
|
||
|
||
<span class="sd"> instance_id: int</span>
|
||
<span class="sd"> See allreduce's documentation for instance_id.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the embedding lookup layer.</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="c1"># Distribute embedding lookup table across multiple GPU</span>
|
||
<span class="k">if</span> <span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">tp_group</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">sharding_dim</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># TP on vocab_size dimension</span>
|
||
<span class="k">if</span> <span class="n">tp_rank</span> <span class="o">==</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"Rank cannot be none for tensor parallelism on vocab dim"</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lookup_plugin</span><span class="p">:</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">tp_rank</span><span class="p">)</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">allreduce</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">,</span> <span class="n">workspace</span><span class="p">,</span> <span class="n">instance_id</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">shape_weight</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">weight</span><span class="p">)</span>
|
||
<span class="n">vocab_size</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">shape_weight</span><span class="p">,</span> <span class="n">starts</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">sizes</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
|
||
<span class="n">tmp_input</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">-</span> <span class="n">vocab_size</span> <span class="o">*</span> <span class="n">tp_rank</span>
|
||
|
||
<span class="c1"># Identify the valid indices</span>
|
||
<span class="n">is_qualified</span> <span class="o">=</span> <span class="n">op_and</span><span class="p">(</span><span class="n">tmp_input</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">tmp_input</span> <span class="o"><</span> <span class="n">vocab_size</span><span class="p">)</span>
|
||
<span class="n">is_qualified_expand</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">is_qualified</span><span class="p">,</span>
|
||
<span class="p">[</span><span class="n">is_qualified</span><span class="o">.</span><span class="n">ndim</span><span class="p">()])</span>
|
||
|
||
<span class="c1"># Replace the invalid ones to zero</span>
|
||
<span class="n">placeholder_input</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">is_qualified</span><span class="p">,</span> <span class="n">tmp_input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Get the temporal results</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span>
|
||
<span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">placeholder_input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">tmp_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Set zero for invalid results</span>
|
||
<span class="n">placeholder_tmp</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">is_qualified_expand</span><span class="p">,</span> <span class="n">tmp_output</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">placeholder</span> <span class="o">=</span> <span class="n">placeholder_tmp</span> <span class="o">-</span> <span class="n">placeholder_tmp</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">is_qualified_expand</span><span class="p">,</span> <span class="n">tmp_output</span><span class="p">,</span> <span class="n">placeholder</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Use all reduce to collect the results</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">allreduce</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">,</span> <span class="n">workspace</span><span class="p">,</span> <span class="n">instance_id</span><span class="p">)</span>
|
||
|
||
<span class="k">elif</span> <span class="n">sharding_dim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span> <span class="c1"># TP on hidden dimension</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="c1"># 1. [dim0, local_dim] -> [dim0 * tp_size, local_dim]</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">allgather</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">)</span>
|
||
|
||
<span class="c1"># 2. [dim0 * tp_size, local_dim] -> [dim0, local_dim * tp_size]</span>
|
||
<span class="c1"># 2.1 split</span>
|
||
<span class="n">split_size</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">tp_size</span>
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="n">starts</span> <span class="o">=</span> <span class="p">[</span><span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="mi">0</span><span class="p">]))</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
<span class="n">sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">d</span><span class="p">)</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
<span class="n">sizes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">split_size</span>
|
||
<span class="n">sections</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tp_size</span><span class="p">):</span>
|
||
<span class="n">starts</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">split_size</span> <span class="o">*</span> <span class="n">i</span>
|
||
<span class="n">sections</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
|
||
<span class="c1"># 2.2 concat</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">sections</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s1">'Tensor Parallelism only support splitting Embedding lookup along hidden (sharding_dim==1) and vocab (sharding_dim==0) dimensionis'</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c1"># Store embedding lookup table as a whole</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lookup_plugin</span><span class="p">:</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">x</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="constant_to_tensor_">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.constant_to_tensor_">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">constant_to_tensor_</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="nb">input</span><span class="p">]))</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="nb">float</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">float32</span> <span class="ow">or</span> <span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">float16</span>
|
||
<span class="k">if</span> <span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">([</span><span class="nb">input</span><span class="p">]))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp16_array</span><span class="p">([</span><span class="nb">input</span><span class="p">]))</span>
|
||
|
||
<span class="k">return</span> <span class="nb">input</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="broadcast_helper">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.broadcast_helper">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">broadcast_helper</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
|
||
<span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Helper function to perform a broadcast.</span>
|
||
|
||
<span class="sd"> For each input, that function first creates a constant tensor if the input</span>
|
||
<span class="sd"> is an integer or a float. Then, if needed, it expands the smaller tensor to</span>
|
||
<span class="sd"> make sure its rank is the same as the larger one.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> left : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The first input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> right : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The second input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A pair of tensors of same rank.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">left</span><span class="p">)</span>
|
||
<span class="n">right</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">right</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span>
|
||
<span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||
<span class="n">right</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">left</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
|
||
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o"><</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">></span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
|
||
<span class="n">right</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">left</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="elementwise_binary">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.elementwise_binary">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">elementwise_binary</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="nb">float</span><span class="p">],</span> <span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
|
||
<span class="n">op</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an elementwise operation with two inputs.</span>
|
||
|
||
<span class="sd"> For each input, that function first creates a constant tensor if the input</span>
|
||
<span class="sd"> is an integer or a float. Then, if needed, it expands the smaller tensor to</span>
|
||
<span class="sd"> make sure its rank is the same as the larger one. Then, it performs the</span>
|
||
<span class="sd"> elementwise operation 'op'.</span>
|
||
|
||
<span class="sd"> The following closures are defined in functional.*:</span>
|
||
|
||
<span class="sd"> add for op=trt.ElementWiseOperation.SUM</span>
|
||
<span class="sd"> sub for op=trt.ElementWiseOperation.SUB</span>
|
||
<span class="sd"> mul for op=trt.ElementWiseOperation.PROD</span>
|
||
<span class="sd"> div for op=trt.ElementWiseOperation.DIV</span>
|
||
<span class="sd"> gt for op=trt.ElementWiseOperation.GREATER</span>
|
||
<span class="sd"> lt for op=trt.ElementWiseOperation.LESS</span>
|
||
<span class="sd"> op_and for op=trt.ElementWiseOperation.AND</span>
|
||
<span class="sd"> op_or for op=trt.ElementWiseOperation.OR</span>
|
||
<span class="sd"> eq for op=trt.ElementWiseOperation.EQUAL</span>
|
||
<span class="sd"> minimum for op=trt.ElementWiseOperation.MIN</span>
|
||
<span class="sd"> maximum for op=trt.ElementWiseOperation.MAX</span>
|
||
<span class="sd"> pow for op=trt.ElementWiseOperation.POW</span>
|
||
|
||
<span class="sd"> It is implemented using the IElementWiseLayer from TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> left : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The first input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> right : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The second input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> op : trt.ElementWiseOperation</span>
|
||
<span class="sd"> The binary operation to perform.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this elementwise operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">left</span><span class="p">,</span> <span class="n">right</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_elementwise</span><span class="p">(</span><span class="n">left</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">right</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">op</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="n">add</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">SUM</span><span class="p">)</span>
|
||
<span class="n">sub</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">SUB</span><span class="p">)</span>
|
||
<span class="n">mul</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">PROD</span><span class="p">)</span>
|
||
<span class="n">div</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">DIV</span><span class="p">)</span>
|
||
<span class="n">gt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">GREATER</span><span class="p">)</span>
|
||
<span class="n">lt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">LESS</span><span class="p">)</span>
|
||
<span class="n">op_and</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">AND</span><span class="p">)</span>
|
||
<span class="n">op_or</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">OR</span><span class="p">)</span>
|
||
<span class="n">eq</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">EQUAL</span><span class="p">)</span>
|
||
<span class="n">minimum</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">MIN</span><span class="p">)</span>
|
||
<span class="n">maximum</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">)</span>
|
||
<span class="nb">pow</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">POW</span><span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="where">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.where">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">where</span><span class="p">(</span><span class="n">condition</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
|
||
<span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a where (aka select or if-then-else) operation.</span>
|
||
|
||
<span class="sd"> Assuming the three input parameters have the same shape, that function creates</span>
|
||
<span class="sd"> the operation to compute a tensor of the same shape such that:</span>
|
||
|
||
<span class="sd"> for ii in range(mul(condition.shape)):</span>
|
||
<span class="sd"> output[ii] = left[ii] if condition[ii] else right[ii]</span>
|
||
|
||
<span class="sd"> For each input, that function first creates a constant tensor if the input</span>
|
||
<span class="sd"> is an integer or a float. Then, if needed, it expands the smaller tensor to</span>
|
||
<span class="sd"> make sure its rank is the same as the larger one. Then, it performs the</span>
|
||
<span class="sd"> selection.</span>
|
||
|
||
<span class="sd"> It is implemented using the ISelectLayer from TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> left : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The condition. If that input is an integer or a float, the function</span>
|
||
<span class="sd"> creates a constant tensor.</span>
|
||
|
||
<span class="sd"> left : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The first input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> right : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The second input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> op : trt.ElementWiseOperation</span>
|
||
<span class="sd"> The binary operation to perform.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this select operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="c1"># Convert to tensors.</span>
|
||
<span class="n">condition</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">condition</span><span class="p">)</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">left</span><span class="p">)</span>
|
||
<span class="n">right</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">right</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Find the tensor with the largest rank of the three.</span>
|
||
<span class="n">largest</span> <span class="o">=</span> <span class="n">condition</span>
|
||
<span class="k">if</span> <span class="n">largest</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o"><</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
|
||
<span class="n">largest</span> <span class="o">=</span> <span class="n">left</span>
|
||
<span class="k">if</span> <span class="n">largest</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o"><</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
|
||
<span class="n">largest</span> <span class="o">=</span> <span class="n">right</span>
|
||
|
||
<span class="c1"># Expand the tensors to match the largest one.</span>
|
||
<span class="k">if</span> <span class="n">condition</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
|
||
<span class="n">condition</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">left</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">right</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
|
||
<span class="n">right</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Insert the operation.</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_select</span><span class="p">(</span><span class="n">condition</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">left</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">right</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="unary">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.unary">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">unary</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">op</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an elementwise operation on a single input.</span>
|
||
|
||
<span class="sd"> The following closures are defined in functional.*:</span>
|
||
|
||
<span class="sd"> round for op=trt.UnaryOperation.ROUND</span>
|
||
<span class="sd"> sqrt for op=trt.UnaryOperation.SQRT</span>
|
||
<span class="sd"> exp for op=trt.UnaryOperation.EXP</span>
|
||
<span class="sd"> sin for op=trt.UnaryOperation.SIN</span>
|
||
<span class="sd"> cos for op=trt.UnaryOperation.COS</span>
|
||
<span class="sd"> abs for op=trt.UnaryOperation.ABS</span>
|
||
|
||
<span class="sd"> It is implemented using the IUnaryLayer from TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> op : trt.UnaryOperation</span>
|
||
<span class="sd"> The unary operation to perform.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this elementwise operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_unary</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="nb">round</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">ROUND</span><span class="p">)</span>
|
||
<span class="n">sqrt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">SQRT</span><span class="p">)</span>
|
||
<span class="n">exp</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">EXP</span><span class="p">)</span>
|
||
<span class="n">sin</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">SIN</span><span class="p">)</span>
|
||
<span class="n">cos</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">COS</span><span class="p">)</span>
|
||
<span class="nb">abs</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">ABS</span><span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="mean">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.mean">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">mean</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to compute the mean along a dimension.</span>
|
||
|
||
<span class="sd"> Computes the mean along the dimension 'dim' of the input tensor.</span>
|
||
|
||
<span class="sd"> It is implemented using the IReduceLayer from TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension along which the mean is computed.</span>
|
||
|
||
<span class="sd"> keepdim : bool</span>
|
||
<span class="sd"> Is the dimension kept in the reduced tensor? When True the</span>
|
||
<span class="sd"> dimension is kept, it is removed from the shape otherwise.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this reduction operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_reduce</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">AVG</span><span class="p">,</span>
|
||
<span class="n">axes</span><span class="p">,</span>
|
||
<span class="n">keep_dims</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="max"><!-- NOTE(review): max() resolves a negative dim with dim_resolve_negative, converts it with dim_to_trt_axes and adds an IReduceLayer with trt.ReduceOperation.MAX. The dim docstring previously said "mean" (copied from mean()); corrected here. This page is generated from the Python source, so fix it upstream as well. -->
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.max">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">max</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to compute the max along a dimension.</span>
|
||
|
||
<span class="sd"> Computes the max along the dimension 'dim' of the input tensor.</span>
|
||
|
||
<span class="sd"> It is implemented using the IReduceLayer from TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension along which the max is computed.</span>
|
||
|
||
<span class="sd"> keepdim : bool</span>
|
||
<span class="sd"> Is the dimension kept in the reduced tensor? When True the</span>
|
||
<span class="sd"> dimension is kept, it is removed from the shape otherwise.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this reduction operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_reduce</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">,</span>
|
||
<span class="n">axes</span><span class="p">,</span>
|
||
<span class="n">keep_dims</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="identity"><!-- NOTE(review): identity() adds a plain TensorRT identity layer unless default_net().plugin_config.identity_plugin is set. In that case it looks up the 'Identity' plugin creator (version '1', TRT_LLM_PLUGIN_NAMESPACE), asserts that it exists, creates the plugin with an empty PluginFieldCollection and adds it with add_plugin_v2. The reason a plugin variant exists is not visible here. -->
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.identity">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">identity</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an identity operation.</span>
|
||
|
||
<span class="sd"> Uses the 'Identity' plugin when plugin_config.identity_plugin is set, and a TensorRT identity layer otherwise. TODO: document why the plugin variant is needed.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this identity operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">identity_plugin</span><span class="p">:</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_identity</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Identity'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">()</span>
|
||
<span class="n">id_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"identity"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">id_plug</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="argmax"><!-- NOTE(review): argmax() adds a TopK layer (TopKOperation.MAX, k=1) over the resolved dim and reads output 1 of that layer, presumably the indices. The keepdim=False branch looks broken. (a) input.ndim() is used as an int elsewhere on this page (group_norm computes range(2, ndim + 1)), so len(input.ndim()) should raise TypeError. (b) int32_array([a]) nests the list, which gives 2-D gather indices. (c) shape(output) and view(output, new_shape) receive the raw layer output instead of a wrapped Tensor. (d) If view() returns a Tensor, layer.get_output(0) will fail. Confirm and fix these in the Python source this page is generated from. -->
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.argmax">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">argmax</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an argmax operation.</span>
|
||
|
||
<span class="sd"> As explained in the ONNX documentation,</span>
|
||
|
||
<span class="sd"> https://github.com/onnx/onnx/blob/main/docs/Operators.md#argmax</span>
|
||
|
||
<span class="sd"> that function creates a layer computing the indices of the max elements of</span>
|
||
<span class="sd"> the input tensor along the provided dim. The resulting tensor</span>
|
||
<span class="sd"> has the same rank as the input if keepdim is True. If keepdim is False,</span>
|
||
<span class="sd"> then the resulting tensor has the reduced dimension pruned.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension in which to compute the argmax indices.</span>
|
||
|
||
<span class="sd"> keepdim : bool</span>
|
||
<span class="sd"> Do we keep the dimension along which the reduction is performed?</span>
|
||
<span class="sd"> Yes, if set to True, no otherwise.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this argmax operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_topk</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">TopKOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">,</span>
|
||
<span class="mi">1</span><span class="p">,</span> <span class="n">axes</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">keepdim</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="n">a</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())))</span>
|
||
<span class="n">a</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
<span class="n">indices</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">a</span><span class="p">]))</span>
|
||
<span class="n">output_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="n">output_shape</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">indices</span><span class="p">)</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">view</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">new_shape</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="gelu"><!-- NOTE(review): gelu() computes the tanh approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))). When the default network is strongly typed, the constants are built with fp32_array, fp16_array or bf16_array to match x.dtype, presumably so that operand dtypes agree. Any other dtype raises KeyError from the dict lookup. -->
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gelu">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">gelu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a GELU operation.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> x : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">x</span> <span class="o">*</span> <span class="p">(</span>
|
||
<span class="n">tanh</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">2.0</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="mf">0.044715</span> <span class="o">*</span> <span class="nb">pow</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mf">3.0</span><span class="p">)))</span> <span class="o">+</span> <span class="mf">1.0</span><span class="p">)</span>
|
||
|
||
<span class="n">array_fn</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span> <span class="n">fp32_array</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span> <span class="n">fp16_array</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">:</span> <span class="n">bf16_array</span><span class="p">,</span>
|
||
<span class="p">}[</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">]</span>
|
||
|
||
<span class="n">v1</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">array_fn</span><span class="p">([</span><span class="mf">0.5</span><span class="p">]))</span>
|
||
<span class="n">v2</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">array_fn</span><span class="p">([</span><span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">2.0</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span><span class="p">)]))</span>
|
||
<span class="n">v3</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">array_fn</span><span class="p">([</span><span class="mf">0.044715</span><span class="p">]))</span>
|
||
<span class="n">v4</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">array_fn</span><span class="p">([</span><span class="mf">3.0</span><span class="p">]))</span>
|
||
<span class="n">v5</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">array_fn</span><span class="p">([</span><span class="mf">1.0</span><span class="p">]))</span>
|
||
<span class="k">return</span> <span class="n">v1</span> <span class="o">*</span> <span class="n">x</span> <span class="o">*</span> <span class="p">(</span><span class="n">tanh</span><span class="p">(</span><span class="n">v2</span> <span class="o">*</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">v3</span> <span class="o">*</span> <span class="nb">pow</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">v4</span><span class="p">)))</span> <span class="o">+</span> <span class="n">v5</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="geglu"><!-- NOTE(review): geglu() splits x with chunk(x, 2, dim=-1) into a and b and returns a * gelu(b). An odd last dimension is not checked here. -->
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.geglu">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">geglu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a Gated-GELU operation.</span>
|
||
|
||
<span class="sd"> That function takes a tensor, splits it into two halves along the last</span>
|
||
<span class="sd"> dimension, applies GELU to the second half and multiplies it by the first half. The</span>
|
||
<span class="sd"> behaviour is undefined if the last dimension is not even.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> x : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">chunk</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">a</span> <span class="o">*</span> <span class="n">gelu</span><span class="p">(</span><span class="n">b</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="group_norm"><!-- NOTE(review): num_channels // num_groups truncates, and divisibility of num_channels by num_groups is not checked here. The variance is the mean of squared deviations (biased). weight and bias are viewed as [num_channels, 1, ...] with ndim - 1 dims, presumably relying on broadcasting against y of rank ndim; confirm. -->
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.group_norm">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">group_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">num_groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">):</span>
|
||
|
||
<span class="c1">##</span>
|
||
<span class="c1">## Group norm: split dim 1 (channels, must be static) into num_groups groups, normalize each group over its channels and trailing dims as (x - mean) / sqrt(var + eps), then apply the optional per-channel weight and bias. TODO: turn this into a docstring.</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">num_channels</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="n">old_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">num_groups</span><span class="p">,</span>
|
||
<span class="n">num_channels</span> <span class="o">//</span> <span class="n">num_groups</span><span class="p">,</span>
|
||
<span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">ndim</span><span class="p">)])</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
|
||
|
||
<span class="n">reduce_dim</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">ndim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
|
||
<span class="n">ux</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">reduce_dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="n">numerator</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">ux</span>
|
||
<span class="n">varx</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">*</span> <span class="n">numerator</span>
|
||
<span class="n">varx</span> <span class="o">=</span> <span class="n">varx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">reduce_dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
|
||
<span class="n">denom</span> <span class="o">=</span> <span class="n">varx</span> <span class="o">+</span> <span class="n">eps</span>
|
||
<span class="n">denom</span> <span class="o">=</span> <span class="n">denom</span><span class="o">.</span><span class="n">sqrt</span><span class="p">()</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denom</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">old_shape</span><span class="p">)</span>
|
||
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">num_channels</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">ndim</span><span class="p">)])</span>
|
||
<span class="k">if</span> <span class="n">weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">*</span> <span class="n">weight</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">bias</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">y</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="softplus"><!-- NOTE(review): softplus() adds a TensorRT SOFTPLUS activation with alpha = 1 / beta and beta = beta, presumably evaluating alpha * log(exp(beta * x) + 1) per TensorRT. It then returns where(input * beta above threshold, input, softplus output), falling back to the identity for large inputs. beta == 0 raises ZeroDivisionError in Python. -->
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.softplus">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">softplus</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add the softplus activation based on the PyTorch definition.</span>
|
||
|
||
<span class="sd"> See https://pytorch.org/docs/stable/generated/torch.nn.functional.softplus.html for a</span>
|
||
<span class="sd"> description of that function.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> Input TensorRT-LLM Tensor.</span>
|
||
<span class="sd"> beta : float</span>
|
||
<span class="sd"> The parameter for softplus computation.</span>
|
||
<span class="sd"> threshold : float</span>
|
||
<span class="sd"> The threshold for reverting to the linear function when input * beta > threshold</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor created by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">sf_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">SOFTPLUS</span><span class="p">)</span>
|
||
<span class="n">sf_layer</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">beta</span>
|
||
<span class="n">sf_layer</span><span class="o">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">beta</span>
|
||
|
||
<span class="n">prod_tensor</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">*</span> <span class="n">beta</span>
|
||
<span class="n">result</span> <span class="o">=</span> <span class="n">prod_tensor</span> <span class="o">></span> <span class="n">threshold</span>
|
||
|
||
<span class="k">return</span> <span class="n">where</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">sf_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">sf_layer</span><span class="p">))</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="outer"><!-- NOTE(review): outer() is a single einsum whose equation maps the index pair i, j of the two operands to the output ij, so both operands are expected to be 1-D. -->
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.outer">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">outer</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">vec2</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to compute the outer product between two tensors.</span>
|
||
|
||
<span class="sd"> That operation creates an Einsum node; both inputs are expected to be 1-D.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The first input tensor.</span>
|
||
|
||
<span class="sd"> vec2 : Tensor</span>
|
||
<span class="sd"> The second input tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor produced by this layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">einsum</span><span class="p">(</span><span class="s1">'i,j->ij'</span><span class="p">,</span> <span class="p">[</span><span class="nb">input</span><span class="p">,</span> <span class="n">vec2</span><span class="p">])</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="avg_pool2d">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.avg_pool2d">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">avg_pool2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
|
||
<span class="n">stride</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">ceil_mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">count_include_pad</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document that function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_pooling</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PoolingType</span><span class="o">.</span><span class="n">AVERAGE</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">stride</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">stride</span> <span class="o">=</span> <span class="n">kernel_size</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">stride</span> <span class="o">=</span> <span class="n">stride</span>
|
||
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
|
||
|
||
<span class="k">return</span> <span class="n">output</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="conv2d">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.conv2d">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">conv2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">dilation</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document that function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="n">kernel_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_convolution_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">noutput</span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="n">padding</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">dilation</span> <span class="o">=</span> <span class="n">dilation</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
|
||
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
|
||
|
||
<span class="k">return</span> <span class="n">output</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="conv_transpose2d">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.conv_transpose2d">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">conv_transpose2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">output_padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">dilation</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document that function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
|
||
<span class="n">kernel_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_deconvolution_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">noutput</span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="n">padding</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
|
||
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
|
||
|
||
<span class="k">return</span> <span class="n">output</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="split">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.split">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">split</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">split_size_or_sections</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that splits a tensor into sub-tensors.</span>
|
||
|
||
<span class="sd"> This operation creates a list of tensors that are obtained from the input</span>
|
||
<span class="sd"> tensor by slicing it along the dimension 'dim'. If 'split_size_or_sections'</span>
|
||
<span class="sd"> is an integer, the tensor is split into 'input.shape[dim] /</span>
|
||
<span class="sd"> split_size_or_sections' slices. If 'split_size_or_sections' is a list of</span>
|
||
<span class="sd"> sizes, the tensor is split into 'len(split_size_or_sections)' slices and</span>
|
||
<span class="sd"> the size of the ith slice is given by 'split_size_or_sections[i]'.</span>
|
||
|
||
<span class="sd"> There are several constraints with the current implementation:</span>
|
||
|
||
<span class="sd"> - The input tensor must be static (no dynamic dimension),</span>
|
||
<span class="sd"> - If 'split_size_or_sections' is an integer, the number of elements in</span>
|
||
<span class="sd"> the 'dim' dimension of the input must be a multiple of</span>
|
||
<span class="sd"> 'split_size_or_sections': 'input.shape[dim] % split_size_or_sections == 0'.</span>
|
||
<span class="sd"> - If 'split_size_or_sections' is a sequence, the sum of the elements in</span>
|
||
<span class="sd"> 'split_size_or_sections' must be equal to the size in the dimension</span>
|
||
<span class="sd"> 'dim': 'input.shape[dim] == sum(ii for ii in split_size_or_sections)'.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a 'slice' operation for each output</span>
|
||
<span class="sd"> slice.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor to slice.</span>
|
||
|
||
<span class="sd"> split_size_or_sections : Union[int, Sequence[int]]</span>
|
||
<span class="sd"> If it is an integer, it encodes the size of each slice. Otherwise,</span>
|
||
<span class="sd"> if it is a sequence, it is the size of each slice.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension of the tensor to slice.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The list of tensors produced by the different operations.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="n">tensor</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">dim</span> <span class="o">+=</span> <span class="n">ndim</span>
|
||
<span class="n">dim_value</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">dim</span><span class="p">]</span>
|
||
<span class="n">starts</span> <span class="o">=</span> <span class="p">[</span><span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="mi">0</span><span class="p">]))</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
<span class="n">sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">split_size_or_sections</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="c1"># TODO: support non-divisible cases</span>
|
||
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">%</span> <span class="n">split_size_or_sections</span> <span class="o">==</span> <span class="mi">0</span>
|
||
<span class="n">num_sections</span> <span class="o">=</span> <span class="n">dim_value</span> <span class="o">//</span> <span class="n">split_size_or_sections</span>
|
||
<span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">split_size_or_sections</span><span class="p">]))</span>
|
||
|
||
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_sections</span><span class="p">):</span>
|
||
<span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">split_size_or_sections</span> <span class="o">*</span> <span class="n">i</span><span class="p">]))</span>
|
||
<span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
|
||
<span class="k">return</span> <span class="n">outputs</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">total_size</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">split_size_or_sections</span><span class="p">:</span>
|
||
<span class="n">total_size</span> <span class="o">+=</span> <span class="n">i</span>
|
||
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">==</span> <span class="n">total_size</span>
|
||
<span class="n">num_sections</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">split_size_or_sections</span><span class="p">)</span>
|
||
|
||
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_sections</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">+</span> <span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span>
|
||
<span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">split_size_or_sections</span><span class="p">[</span><span class="n">i</span><span class="p">]]))</span>
|
||
<span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
|
||
<span class="k">return</span> <span class="n">outputs</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="chunk">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.chunk">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">chunk</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">chunks</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that splits a tensor into sub-tensors.</span>
|
||
|
||
<span class="sd"> This operation creates a list of tensors that are obtained from the input</span>
|
||
<span class="sd"> tensor by chunking it along the dimension 'dim'. It produces 'chunks'</span>
|
||
<span class="sd"> sub-tensors.</span>
|
||
|
||
<span class="sd"> That operation is only defined for static tensors (no dynamic dimension)</span>
|
||
<span class="sd"> and the size of the tensor in the dimension 'dim' must be a multiple of</span>
|
||
<span class="sd"> 'chunks': 'input.shape[dim] % chunks == 0'.</span>
|
||
|
||
<span class="sd"> It maps to 'split' with 'split_size = input.shape[dim] / chunks'.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor to slice.</span>
|
||
|
||
<span class="sd"> chunks : int</span>
|
||
<span class="sd"> The number of slices to split the input tensor into.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension of the tensor to slice.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The list of tensors produced by the different operations.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="n">tensor</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">dim</span> <span class="o">+=</span> <span class="n">ndim</span>
|
||
<span class="n">dim_value</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">dim</span><span class="p">]</span>
|
||
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">%</span> <span class="n">chunks</span> <span class="o">==</span> <span class="mi">0</span>
|
||
|
||
<span class="k">return</span> <span class="n">split</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">dim_value</span> <span class="o">//</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="AllReduceStrategy">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceStrategy">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">AllReduceStrategy</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Warning: actual definition is in cpp/tensorrt_llm/kernels/customAllReduceKernels.h</span>
|
||
<span class="sd"> and they must be kept in sync.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">RING</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">ONESHOT</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">TWOSHOT</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="n">AUTO</span> <span class="o">=</span> <span class="mi">3</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="allreduce">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.allreduce">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">allreduce</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
|
||
<span class="n">workspace</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">instance_id</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AllReduceStrategy</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs a collective all-reduce.</span>
|
||
|
||
<span class="sd"> Let's define 'world_size' as the length of the 'group' list. That function</span>
|
||
<span class="sd"> creates a layer to compute the sum of 'world_size' tensors distributed</span>
|
||
<span class="sd"> amongst the 'world_size' participating ranks (one GPU per rank).</span>
|
||
|
||
<span class="sd"> The list 'group' contains the identifiers of the ranks participating in</span>
|
||
<span class="sd"> the collective operation.</span>
|
||
|
||
<span class="sd"> The tensors in the different ranks must be 1D tensors (or views) and the output</span>
|
||
<span class="sd"> tensor will have that same shape. The output tensor will be replicated on</span>
|
||
<span class="sd"> the 'world_size' ranks.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL all-reduce</span>
|
||
<span class="sd"> collective operation. See</span>
|
||
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce</span>
|
||
<span class="sd"> for details.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> group : List[int]</span>
|
||
<span class="sd"> The ranks participating in the all-reduce operation.</span>
|
||
|
||
<span class="sd"> workspace: Optional[Tensor]</span>
|
||
<span class="sd"> When using a strategy other than RING (ONESHOT, TWOSHOT or AUTO), a tensor containing pointers to memory</span>
|
||
<span class="sd"> visible to all GPUs. It should be 3 pointers per TP rank -</span>
|
||
<span class="sd"> ptr to data buffer, ptr to barriers in, ptr to barriers out.</span>
|
||
<span class="sd"> It must be initialized using IpcMemory class.</span>
|
||
|
||
<span class="sd"> instance_id: int</span>
|
||
<span class="sd"> Used for synchronization with strategies other than RING. Corresponding plugins MUST have the same</span>
|
||
<span class="sd"> instance_id. I.e. GPU#0's allreduce after MLP at layer i must have the same instance_id as</span>
|
||
<span class="sd"> GPU#1, GPU#2... Also, instance_id MUST be unique per model. There should not be two allreduce instances</span>
|
||
<span class="sd"> in GPU#0 that have the same id.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="n">allreduce_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'AllReduce'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">strategy</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_custom_all_reduce</span><span class="p">:</span>
|
||
<span class="n">strategy</span> <span class="o">=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">AUTO</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">strategy</span> <span class="o">=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">RING</span>
|
||
|
||
<span class="k">if</span> <span class="n">strategy</span> <span class="o">!=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">RING</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">allreduce</span><span class="p">,</span> <span class="s2">"ids"</span><span class="p">):</span>
|
||
<span class="n">allreduce</span><span class="o">.</span><span class="n">ids</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
|
||
|
||
<span class="k">if</span> <span class="n">instance_id</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">allreduce</span><span class="o">.</span><span class="n">ids</span><span class="p">:</span>
|
||
<span class="n">allreduce</span><span class="o">.</span><span class="n">ids</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">instance_id</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Custom allreduce has already used id </span><span class="si">{</span><span class="n">instance_id</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
|
||
<span class="k">assert</span> <span class="n">allreduce_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"group"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
|
||
<span class="n">pf_dtype</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="p">[</span><span class="n">group</span><span class="p">,</span> <span class="n">pf_dtype</span><span class="p">]</span>
|
||
<span class="n">p_strategy</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"strategy"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">strategy</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_strategy</span><span class="p">)</span>
|
||
<span class="n">p_counter</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"counter"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">instance_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_counter</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span><span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">ar_plug</span> <span class="o">=</span> <span class="n">allreduce_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"allreduce"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">strategy</span> <span class="o">!=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">RING</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">workspace</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">ar_plug</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="allgather">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.allgather">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">allgather</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs a collective all-gather.</span>
|
||
|
||
<span class="sd"> Let's define 'world_size' as the length of the 'group' list. That function</span>
|
||
<span class="sd"> creates a layer to gather 'world_size' tensors distributed</span>
|
||
<span class="sd"> amongst the 'world_size' participating ranks (one GPU per rank).</span>
|
||
|
||
<span class="sd"> The list 'group' contains the identifiers of the ranks participating in</span>
|
||
<span class="sd"> the collective operation.</span>
|
||
|
||
<span class="sd"> The tensors in the different ranks must be 1D tensors (or views) and the</span>
|
||
<span class="sd"> output tensor will have that same shape.</span>
|
||
|
||
<span class="sd"> Given the 'section_size = input.shape[0] / world_size', each rank</span>
|
||
<span class="sd"> contributes a section of its input tensor that correspond to</span>
|
||
<span class="sd"> 'rank*section_size:(rank+1)*section_size'.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL all-gather</span>
|
||
<span class="sd"> collective operation. See</span>
|
||
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather</span>
|
||
<span class="sd"> for details.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> group : List[int]</span>
|
||
<span class="sd"> The ranks participating in the all-gather operation.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">allgather_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'AllGather'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">allgather_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"group"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">group</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">allgather</span> <span class="o">=</span> <span class="n">allgather_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"allgather"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">allgather</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="send">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.send">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">send</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs a send from a rank to another.</span>
|
||
|
||
<span class="sd"> The send operation sends a tensor from one rank to another. If a rank 'i'</span>
|
||
<span class="sd"> sends a tensor to a rank 'j', the rank 'j' must have a corresponding 'recv'</span>
|
||
<span class="sd"> operation from rank 'i'. See 'recv'.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL send</span>
|
||
<span class="sd"> point-to-point operation. See</span>
|
||
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend</span>
|
||
<span class="sd"> for details.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> tgt : int</span>
|
||
<span class="sd"> The rank that receives the tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">send_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Send'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">send_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">tgt</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"tgt_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tgt</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">tgt</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">send_plug</span> <span class="o">=</span> <span class="n">send_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"send"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">send_plug</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="recv">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.recv">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">recv</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs a recv on a rank from another.</span>
|
||
|
||
<span class="sd"> The recv operation receives a tensor on a rank from another. If a rank 'i'</span>
|
||
<span class="sd"> receives a tensor from a rank 'j', the rank 'j' must have a corresponding 'send'</span>
|
||
<span class="sd"> operation to rank 'i'. See 'send'.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL recv</span>
|
||
<span class="sd"> point-to-point operation. See</span>
|
||
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv</span>
|
||
<span class="sd"> for details.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> src : int</span>
|
||
<span class="sd"> The rank that sends the tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">recv_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Recv'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">recv_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">src</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"src_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">src</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">recv_plug</span> <span class="o">=</span> <span class="n">recv_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"recv"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">recv_plug</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="bert_attention">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.bert_attention">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">bert_attention</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">input_lengths</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">head_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">q_scaling</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
|
||
<span class="n">relative_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">relative_attention_bias</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">max_distance</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">max_input_length</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs the multi-head attention in BERT.</span>
|
||
|
||
<span class="sd"> The multihead-attention (MHA) is the sequence of a batched matmul, a</span>
|
||
<span class="sd"> softmax and a batched matmul as described in</span>
|
||
<span class="sd"> https://arxiv.org/abs/1706.03762. This function adds an operation that</span>
|
||
<span class="sd"> performs those computations using a single GPU kernel.</span>
|
||
|
||
<span class="sd"> The input tensor contains the Q, K and V elements. It is a 2D tensor and</span>
|
||
<span class="sd"> its shape is '[sum_of_tokens, 3*hidden_dim]' where the 'sum_of_tokens' is</span>
|
||
<span class="sd"> the sum of the sequence lengths in the batch.</span>
|
||
|
||
<span class="sd"> In MHA, the output of the Q*K^T product is scaled by a constant value that</span>
|
||
<span class="sd"> is computed as:</span>
|
||
|
||
<span class="sd"> 1.f / (q_scaling * sqrt(head_size)).</span>
|
||
|
||
<span class="sd"> The 'q_scaling' constant is supplied through the 'q_scaling' argument.</span>
|
||
|
||
<span class="sd"> This layer is implemented using a plugin (see bertAttentionPlugin).</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The QKV input tensor.</span>
|
||
|
||
<span class="sd"> input_lengths : Tensor</span>
|
||
<span class="sd"> The length of each sequence. It is a 1D tensor of size 'batch_size'.</span>
|
||
|
||
<span class="sd"> num_heads : int</span>
|
||
<span class="sd"> The number of heads.</span>
|
||
|
||
<span class="sd"> head_size : int</span>
|
||
<span class="sd"> The size of each head.</span>
|
||
|
||
<span class="sd"> q_scaling : float</span>
|
||
<span class="sd"> The factor used to compute the scaling factor applied to the output of the</span>
|
||
<span class="sd"> 'Q*K^T' product.</span>
|
||
|
||
<span class="sd"> relative_attention: bool = False</span>
|
||
<span class="sd"> Whether to enable relative attention.</span>
|
||
|
||
<span class="sd"> relative_attention_bias: Tensor = None</span>
|
||
<span class="sd"> The relative attention bias [num_heads, max_seq_len, max_seq_len], or the relative attention embedding table for implicit mode, [num_heads, num_buckets].</span>
|
||
|
||
<span class="sd"> max_distance: int = 0</span>
|
||
<span class="sd"> The maximum distance of relative position in attention, for implicit mode.</span>
|
||
<span class="sd"> Default value is 0, meaning to use the regular mode of relative attention bias.</span>
|
||
<span class="sd"> Implicit mode is only enabled when a positive max_distance value is passed in.</span>
|
||
<span class="sd"> See relative attention bias in docs/gpt_attention.md</span>
|
||
|
||
<span class="sd"> max_input_length: Tensor = None</span>
|
||
<span class="sd"> The maximum input sequence length represented by Tensor shape. Required for remove_input_padding to pre-define the plugin workspace size.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">attn_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'BertAttention'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">attn_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">nheads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"num_heads"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">head_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"head_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">head_size</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">q_scaling</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"q_scaling"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">q_scaling</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">enable_qk_half_accum</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"enable_qk_half_accum"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span>
|
||
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">attention_qk_half_accumulation</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">context_fmha_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"context_fmha_type"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha_type</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">bert_attention_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">do_relative_attention</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"do_relative_attention"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">relative_attention</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">max_distance</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"max_distance"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_distance</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">remove_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"remove_padding"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
|
||
<span class="n">nheads</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span> <span class="n">q_scaling</span><span class="p">,</span> <span class="n">enable_qk_half_accum</span><span class="p">,</span> <span class="n">context_fmha_type</span><span class="p">,</span>
|
||
<span class="n">pf_type</span><span class="p">,</span> <span class="n">do_relative_attention</span><span class="p">,</span> <span class="n">max_distance</span><span class="p">,</span> <span class="n">remove_padding</span>
|
||
<span class="p">])</span>
|
||
|
||
<span class="n">attn_plug</span> <span class="o">=</span> <span class="n">attn_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"padding_attn"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="p">,</span> <span class="n">input_lengths</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">max_input_length</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="c1"># for remove padding mode</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">max_input_length</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">relative_attention_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="c1"># for relative attention mode</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">relative_attention_bias</span><span class="p">]</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">attn_plug</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> \
|
||
<span class="sa">f</span><span class="s2">"Plugin outputs number mismatch with expected, got </span><span class="si">{</span><span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span><span class="si">}</span><span class="s2">, expected 1"</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">return</span> <span class="n">output</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="gpt_attention">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gpt_attention">[docs]</a>
|
||
<span class="nd">@gw</span><span class="o">.</span><span class="n">record_signature</span>
|
||
<span class="k">def</span> <span class="nf">gpt_attention</span><span class="p">(</span>
|
||
<span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">past_key_value</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">sequence_length</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">host_past_key_value_lengths</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">host_max_kv_cache_lengths</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">num_kv_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">hidden_size_per_head</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">q_scaling</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_base</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_scale_type</span><span class="p">:</span> <span class="n">RotaryScalingType</span> <span class="o">=</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">none</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_max_positions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1024</span><span class="p">,</span>
|
||
<span class="n">position_embedding_type</span><span class="p">:</span> <span class="n">PositionEmbeddingType</span> <span class="o">=</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span>
|
||
<span class="n">learned_absolute</span><span class="p">,</span>
|
||
<span class="n">kv_orig_quant_scale</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">kv_quant_orig_scale</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">kv_cache_quant_mode</span><span class="p">:</span> <span class="n">QuantMode</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">mask_type</span><span class="p">:</span> <span class="n">AttentionMaskType</span> <span class="o">=</span> <span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">causal</span><span class="p">,</span>
|
||
<span class="n">alibi_slopes</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">tp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">tp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">do_cross_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">cross_qkv</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
|
||
<span class="n">cross_qkv_length</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
|
||
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
|
||
<span class="n">relative_attention_bias</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for relative attention</span>
|
||
<span class="n">max_distance</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="c1"># for relative attention</span>
|
||
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for pad-free input mode</span>
|
||
<span class="n">qkv_bias</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs the multi-head attention in GPT-like models.</span>
|
||
|
||
<span class="sd"> The signature of the function will change in the future release - we are in</span>
|
||
<span class="sd"> the process of simplifying the API. The current version is still</span>
|
||
<span class="sd"> work-in-progress! The following API is provided with hints regarding the</span>
|
||
<span class="sd"> arguments that are likely to be removed or merged with others in the future</span>
|
||
<span class="sd"> release.</span>
|
||
|
||
<span class="sd"> See docs/gpt_attention.md for the documentation of that function.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor: Tensor</span>
|
||
<span class="sd"> The input QKV tensor. Its shape is [batch_beam_size, max_seqlen, 3</span>
|
||
<span class="sd"> * hidden_dim] in padded mode and [1, num_tokens, 3 * hidden_dim] in</span>
|
||
<span class="sd"> packed mode. See QKV Input in docs/gpt_attention.md.</span>
|
||
|
||
<span class="sd"> past_key_value: Tensor</span>
|
||
<span class="sd"> The tensor that stores KV cache data. Its shape is</span>
|
||
<span class="sd"> [max_batch_size * max_beam_width, 2, num_heads, max_seqlen, hidden_dim_per_head]</span>
|
||
<span class="sd"> in contiguous mode and</span>
|
||
<span class="sd"> [max_blocks, 2, num_heads, num_tokens_per_block, hidden_dim_per_head]</span>
|
||
<span class="sd"> in paged mode. See KV Cache in docs/gpt_attention.md,</span>
|
||
|
||
<span class="sd"> sequence_lengths: Tensor</span>
|
||
<span class="sd"> The tensor that stores the length of each sequence. Its shape is</span>
|
||
<span class="sd"> [batch_size]. See QKV Input in docs/gpt_attention.md,</span>
|
||
|
||
<span class="sd"> host past_key_value_length: Tensor</span>
|
||
<span class="sd"> An INT32 tensor of shape [batch_size].</span>
|
||
|
||
<span class="sd"> host max_kv_cache_lengths: Tensor</span>
|
||
<span class="sd"> An INT32 tensor of shape [1].</span>
|
||
<span class="sd"> by default, the max_kv_cache_length is determined by the shape of cache_indir_table.</span>
|
||
<span class="sd"> And we support flexible max_kv_cache_length (or max_past_length) for each layer.</span>
|
||
|
||
<span class="sd"> context_lengths: Tensor</span>
|
||
<span class="sd"> The tensor that stores the context-phase sequence length of each request. Its shape</span>
|
||
<span class="sd"> is [batch_size]. See QKV Input in doc/functional.py,</span>
|
||
|
||
<span class="sd"> cache_indirection: Tensor</span>
|
||
<span class="sd"> The tensor to reconstruct the paths when using beam-search. Its</span>
|
||
<span class="sd"> shape is [batch_size, beam_width, max_seqlen]. See Beam-Search in</span>
|
||
<span class="sd"> docs/gpt_attention.md,</span>
|
||
|
||
<span class="sd"> host_request_types: Tensor = None</span>
|
||
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
|
||
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
|
||
<span class="sd"> in docs/gpt_attention.md,</span>
|
||
|
||
<span class="sd"> num_heads: int</span>
|
||
<span class="sd"> The number of heads,</span>
|
||
|
||
<span class="sd"> num_kv_heads: int</span>
|
||
<span class="sd"> The number of KV heads, generic to handle MHA/MQA/GQA,</span>
|
||
|
||
<span class="sd"> hidden_size_per_head: int</span>
|
||
<span class="sd"> The hidden size per head,</span>
|
||
|
||
<span class="sd"> q_scaling: float</span>
|
||
<span class="sd"> The value used to compute the scaling factor applied to the output</span>
|
||
<span class="sd"> of the Q*K^T product. See Scaling Factors in docs/gpt_attention.md,</span>
|
||
|
||
<span class="sd"> rotary_embedding_dim: int</span>
|
||
<span class="sd"> The dimension to compute RoPE. Use 0 when position_embedding_type is not RoPE.</span>
|
||
|
||
<span class="sd"> rotary_embedding_base: float</span>
|
||
<span class="sd"> The theta value to use for RoPE. Ignored when position_embedding_type is not RoPE.</span>
|
||
|
||
<span class="sd"> rotary_embedding_scale_type: RotaryScalingType</span>
|
||
<span class="sd"> The scaling type of RoPE. Ignored when position_embedding_type is not RoPE.</span>
|
||
<span class="sd"> Possible rotary scaling type:</span>
|
||
<span class="sd"> * RotaryScalingType.none</span>
|
||
<span class="sd"> * RotaryScalingType.linear</span>
|
||
<span class="sd"> * RotaryScalingType.dynamic</span>
|
||
|
||
<span class="sd"> rotary_embedding_scale: float</span>
|
||
<span class="sd"> The scale value to use for linear/dynamic scaling in RoPE.</span>
|
||
<span class="sd"> Ignored when position_embedding_type is not RoPE.</span>
|
||
<span class="sd"> Must be set to 1 (default) if rotary_embedding_scale_type is `none`.</span>
|
||
|
||
<span class="sd"> rotary_embedding_max_positions: int</span>
|
||
<span class="sd"> Needed only for `dynamic` RoPE scaling. Ignored otherwise.</span>
|
||
|
||
<span class="sd"> position_embedding_type: PositionEmbeddingType</span>
|
||
<span class="sd"> The position embedding type:</span>
|
||
<span class="sd"> * PositionEmbeddingType.learned_absolute</span>
|
||
<span class="sd"> * PositionEmbeddingType.relative</span>
|
||
<span class="sd"> * PositionEmbeddingType.rope_gptj</span>
|
||
<span class="sd"> * PositionEmbeddingType.rope_gpt_neox</span>
|
||
<span class="sd"> * PositionEmbeddingType.alibi</span>
|
||
<span class="sd"> * PositionEmbeddingType.alibi_with_scale</span>
|
||
|
||
<span class="sd"> kv_orig_quant_scale: Tensor</span>
|
||
<span class="sd"> The tensor to store the scaling factor for quantization to INT8/FP8</span>
|
||
<span class="sd"> in the KV cache. Its shape is [1]. See INT8/FP8 KV Cache in</span>
|
||
<span class="sd"> docs/gpt_attention.md,</span>
|
||
|
||
<span class="sd"> kv_quant_orig_scale: Tensor</span>
|
||
<span class="sd"> The tensor to store the scaling factor for dequantization from</span>
|
||
<span class="sd"> INT8/FP8 in the KV cache. Its shape is [1]. See INT8/FP8 KV Cache</span>
|
||
<span class="sd"> in docs/gpt_attention.md,</span>
|
||
|
||
<span class="sd"> kv_cache_quant_mode: QuantMode (int flags)</span>
|
||
<span class="sd"> Do we enable the INT8 or FP8 KV cache?</span>
|
||
|
||
<span class="sd"> max_context_length: int32_t</span>
|
||
<span class="sd"> The length of the longest input sequence. See QKV Input in</span>
|
||
<span class="sd"> docs/gpt_attention.md,</span>
|
||
|
||
<span class="sd"> mask_type: int = 1</span>
|
||
<span class="sd"> The type of mask:</span>
|
||
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.padding for BERT,</span>
|
||
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.causal for GPT,</span>
|
||
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.bidirectional for ChatGLM-6B,</span>
|
||
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.bidirectionalglm for GLM-10B,</span>
|
||
|
||
<span class="sd"> alibi_slopes: Tensor</span>
|
||
<span class="sd"> The ALiBi slopes. The ALiBi bias is computed on-the-fly in the kernel</span>
|
||
<span class="sd"> when possible,</span>
|
||
|
||
<span class="sd"> tp_size: int</span>
|
||
<span class="sd"> The number of processes/GPUs when tensor parallelism is activated,</span>
|
||
|
||
<span class="sd"> tp_rank: int</span>
|
||
<span class="sd"> The rank of that process (when running tensor parallelism),</span>
|
||
|
||
<span class="sd"> kv_cache_block_pointers:</span>
|
||
<span class="sd"> The tensor of block pointers for the KV cache. Its shape is</span>
|
||
<span class="sd"> [max_batch_size, max_beam_width, 2, max_blocks_per_sequence * 2]</span>
|
||
<span class="sd"> See KV cache section in docs/gpt_attention.md,</span>
|
||
|
||
<span class="sd"> do_cross_attention: bool = False</span>
|
||
<span class="sd"> Do we use this as cross attention instead of self attention,</span>
|
||
|
||
<span class="sd"> cross_qkv: Tensor = None</span>
|
||
<span class="sd"> The QKV tensor of encoder output hidden states. Its shape is [batch_size, max_seqlen, 3</span>
|
||
<span class="sd"> * hidden_dim] in padded mode and [1, num_tokens, 3 * hidden_dim] in</span>
|
||
<span class="sd"> packed mode,</span>
|
||
|
||
<span class="sd"> cross_qkv_length: Tensor = None</span>
|
||
<span class="sd"> The length of the longest encoder output sequence,</span>
|
||
|
||
<span class="sd"> encoder_input_lengths: Tensor</span>
|
||
<span class="sd"> The tensor that stores the length of each encoder input sequence. Its shape is [batch_size],</span>
|
||
|
||
<span class="sd"> relative_attention_bias: Tensor = None</span>
|
||
<span class="sd"> The relative attention bias [num_heads, max_seq_len, max_seq_len], or The relative attention embedding table for implicit mode, [num_heads, num_buckets].</span>
|
||
|
||
<span class="sd"> max_distance: int = 0</span>
|
||
<span class="sd"> The maximum distance of relative position in attention, for implicit mode.</span>
|
||
<span class="sd"> Default value is 0, meaning to use the regular mode of relative attention bias.</span>
|
||
<span class="sd"> Implicit mode is only enabled when passing in non-zero positive max_distance value.</span>
|
||
<span class="sd"> See relative attention bias in docs/gpt_attention.md</span>
|
||
|
||
<span class="sd"> host_context_lengths: Tensor = None</span>
|
||
<span class="sd"> A host tensor that contains the lengths of the different inputs,</span>
|
||
|
||
<span class="sd"> qkv_bias: Tensor = None,</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">host_request_types</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="p">(</span><span class="n">alibi_slopes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">==</span> <span class="p">(</span><span class="n">position_embedding_type</span><span class="o">.</span><span class="n">is_alibi</span><span class="p">())</span>
|
||
<span class="n">attn_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'GPTAttention'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">attn_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">host_context_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">(</span>
|
||
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">host_max_kv_cache_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">paged_kv_cache_flag</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_kv_cache</span>
|
||
|
||
<span class="n">nheads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"num_heads"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">num_kv_heads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"num_kv_heads"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_kv_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">head_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"head_size"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">hidden_size_per_head</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">unidirectional</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"unidirectional"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">q_scaling</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"q_scaling"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">q_scaling</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_dim"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_base</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_base"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_base</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_scale_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_scale_type"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_scale_type</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_scale"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_scale</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_max_positions</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_max_positions"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_max_positions</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">position_embedding_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"position_embedding_type"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">position_embedding_type</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">context_fmha_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"context_fmha_type"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha_type</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"remove_input_padding"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">mask_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"mask_type"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">mask_type</span><span class="p">)],</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">multi_block_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"multi_block_mode"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">multi_block_mode</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">tp_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"tp_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tp_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">tp_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"tp_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tp_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">kv_cache_quant_mode_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"kv_cache_quant_mode"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">kv_cache_quant_mode</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">paged_kv_cache</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"paged_kv_cache"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">paged_kv_cache_flag</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">tokens_per_block</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"tokens_per_block"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"max_context_length"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">qkv_bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">qkv_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"qkv_bias_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">qkv_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"qkv_bias_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">do_cross_attention_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"do_cross_attention"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">do_cross_attention</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">max_distance</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"max_distance"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_distance</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
|
||
<span class="n">nheads</span><span class="p">,</span> <span class="n">num_kv_heads</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span> <span class="n">unidirectional</span><span class="p">,</span> <span class="n">q_scaling</span><span class="p">,</span>
|
||
<span class="n">position_embedding_type</span><span class="p">,</span> <span class="n">rotary_embedding_dim</span><span class="p">,</span> <span class="n">rotary_embedding_base</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_scale_type</span><span class="p">,</span> <span class="n">rotary_embedding_scale</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_max_positions</span><span class="p">,</span> <span class="n">tp_size</span><span class="p">,</span> <span class="n">tp_rank</span><span class="p">,</span> <span class="n">context_fmha_type</span><span class="p">,</span>
|
||
<span class="n">multi_block_mode</span><span class="p">,</span> <span class="n">kv_cache_quant_mode_field</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span>
|
||
<span class="n">mask_type</span><span class="p">,</span> <span class="n">paged_kv_cache</span><span class="p">,</span> <span class="n">tokens_per_block</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span>
|
||
<span class="n">max_context_length</span><span class="p">,</span> <span class="n">qkv_bias_enabled</span><span class="p">,</span> <span class="n">do_cross_attention_field</span><span class="p">,</span>
|
||
<span class="n">max_distance</span>
|
||
<span class="p">])</span>
|
||
|
||
<span class="n">attn_plug</span> <span class="o">=</span> <span class="n">attn_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"causal_attn"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">tensor</span><span class="p">,</span>
|
||
<span class="n">sequence_length</span><span class="p">,</span>
|
||
<span class="n">host_past_key_value_lengths</span><span class="p">,</span>
|
||
<span class="n">host_max_kv_cache_lengths</span><span class="p">,</span>
|
||
<span class="n">context_lengths</span><span class="p">,</span>
|
||
<span class="n">cache_indirection</span><span class="p">,</span>
|
||
<span class="n">host_request_types</span><span class="p">,</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">kv_cache_block_pointers</span><span class="p">]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">past_key_value</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">kv_cache_quant_mode</span><span class="o">.</span><span class="n">has_kv_cache_quant</span><span class="p">():</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">kv_orig_quant_scale</span><span class="p">,</span> <span class="n">kv_quant_orig_scale</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">alibi_slopes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">alibi_slopes</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">relative_attention_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">relative_attention_bias</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">do_cross_attention</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">cross_qkv</span><span class="p">,</span> <span class="n">cross_qkv_length</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">qkv_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">qkv_bias</span><span class="p">]</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">attn_plug</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="n">present_key_value</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
|
||
<span class="n">present_key_value</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">present_key_value</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">expected_outputs</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">expected_outputs</span> <span class="o">=</span> <span class="mi">1</span>
|
||
|
||
<span class="k">assert</span> <span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span> <span class="o">==</span> <span class="n">expected_outputs</span><span class="p">,</span> \
|
||
<span class="sa">f</span><span class="s2">"Plugin outputs number mismatch with expected, got </span><span class="si">{</span><span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span><span class="si">}</span><span class="s2">, expected </span><span class="si">{</span><span class="n">expected_outputs</span><span class="si">}</span><span class="s2">"</span>
|
||
|
||
<span class="k">if</span> <span class="n">kv_cache_quant_mode</span><span class="o">.</span><span class="n">has_int8_kv_cache</span><span class="p">()</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
|
||
<span class="c1"># past key value</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">7</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
<span class="c1"># present key value</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
|
||
<span class="k">assert</span> <span class="n">output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">present_key_value</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="assertion">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.assertion">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">assertion</span><span class="p">(</span><span class="n">condition</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">message</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">''</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_assertion</span><span class="p">(</span><span class="n">condition</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">message</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="layer_norm">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.layer_norm">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">layer_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">,</span>
|
||
<span class="n">use_diff_of_squares</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a layer-norm operation on a tensor.</span>
|
||
|
||
<span class="sd"> That operation applies the layer-normalization to its input tensor. In its</span>
|
||
<span class="sd"> simplest form, for large language models, the 'normalized_shape' should be</span>
|
||
<span class="sd"> set to the hidden dimension of the activation tensor. Otherwise, it is the</span>
|
||
<span class="sd"> shape of the normalized fraction of the tensor (starting from the</span>
|
||
<span class="sd"> right-most dimension).</span>
|
||
|
||
<span class="sd"> The 'weight' tensor corresponds to 'gamma' in the layer-norm formula and</span>
|
||
<span class="sd"> 'bias' is 'beta'. The 'eps' value is added to the variance before computing</span>
|
||
<span class="sd"> the squared-root.</span>
|
||
|
||
<span class="sd"> This implementation (when using the plugin) supports an additional flag to</span>
|
||
<span class="sd"> enable/disable the use of a difference of squares ('Var = Mean(X^2) -</span>
|
||
<span class="sd"> Mean(X)^2').</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The tensor to normalize.</span>
|
||
|
||
<span class="sd"> normalized_shape : Union[int, Tuple[int]]</span>
|
||
<span class="sd"> The shape of the sub-tensor that is normalized. Use 'hidden_dim' to</span>
|
||
<span class="sd"> normalize the inner-most dimension of an activation tensor in LLMs.</span>
|
||
|
||
<span class="sd"> weight : Optional[Tensor] = None</span>
|
||
<span class="sd"> The 'gamma' term in layer-norm. Its shape must be</span>
|
||
<span class="sd"> 'normalized_shape'.</span>
|
||
|
||
<span class="sd"> bias : Optional[Tensor] = None</span>
|
||
<span class="sd"> The 'beta' term in layer-norm. Its shape must be</span>
|
||
<span class="sd"> 'normalized_shape'.</span>
|
||
|
||
<span class="sd"> eps : float</span>
|
||
<span class="sd"> The epsilon term to be added to the variance in the squared-root.</span>
|
||
|
||
<span class="sd"> use_diff_of_squares : bool</span>
|
||
<span class="sd"> Does the plugin use the difference of squares to compute the</span>
|
||
<span class="sd"> variance?</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor of that operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">layernorm_plugin</span><span class="p">:</span>
|
||
<span class="nb">input</span><span class="p">,</span> <span class="n">weight</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">weight</span><span class="p">)</span>
|
||
<span class="nb">input</span><span class="p">,</span> <span class="n">bias</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span> <span class="c1"># FIXME: better way?</span>
|
||
<span class="n">axis</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">axis</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">)</span>
|
||
<span class="n">axes_mask</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()):</span>
|
||
<span class="n">axes_mask</span> <span class="o">|=</span> <span class="mi">1</span> <span class="o"><<</span> <span class="n">i</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_normalization</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">axes_mask</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">epsilon</span> <span class="o">=</span> <span class="n">eps</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Layernorm'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">eps</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"eps"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">eps</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">use_diff_of_squares</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"use_diff_of_squares"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">use_diff_of_squares</span><span class="p">)],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">layernorm_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">eps</span><span class="p">,</span> <span class="n">use_diff_of_squares</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">layernorm_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"layernorm"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
|
||
<span class="n">normalized_shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">normalized_shape</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
|
||
<span class="n">normalized_shape</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="n">normalized_shape</span>
|
||
<span class="k">if</span> <span class="n">weight</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">weight</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">str_dtype_to_np</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)))</span>
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">bias</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">str_dtype_to_np</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)))</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">layernorm_plug</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="rms_norm">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.rms_norm">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">rms_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-06</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a RMS norm operation on a tensor.</span>
|
||
|
||
<span class="sd"> That operation applies the rms-normalization to its input tensor. In its</span>
|
||
<span class="sd"> simplest form, for large language models, the 'normalized_shape' should be</span>
|
||
<span class="sd"> set to the hidden dimension of the activation tensor. Otherwise, it is the</span>
|
||
<span class="sd"> shape of the normalized fraction of the tensor (starting from the</span>
|
||
<span class="sd"> right-most dimension).</span>
|
||
|
||
<span class="sd"> The 'weight' tensor corresponds to 'gamma' in the rms-norm formula.</span>
|
||
<span class="sd"> The 'eps' value is added to the variance before computing the squared-root.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input: Tensor</span>
|
||
<span class="sd"> The tensor to normalize.</span>
|
||
|
||
<span class="sd"> normalized_shape : Union[int, Tuple[int]]</span>
|
||
<span class="sd"> The shape of the sub-tensor that is normalized. Use 'hidden_dim' to</span>
|
||
<span class="sd"> normalize the inner-most dimension of an activation tensor in LLMs.</span>
|
||
|
||
<span class="sd"> weight : Optional[Tensor] = None</span>
|
||
<span class="sd"> The 'gamma' term in layer-norm. Its shape must be</span>
|
||
<span class="sd"> 'normalized_shape'.</span>
|
||
|
||
<span class="sd"> eps : float</span>
|
||
<span class="sd"> The epsilon term to be added to the variance in the squared-root.weig</span>
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor of that operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">rmsnorm_plugin</span><span class="p">:</span>
|
||
<span class="n">normalized_shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">normalized_shape</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
|
||
<span class="n">normalized_shape</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="n">normalized_shape</span>
|
||
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="o">-</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">))])</span>
|
||
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="n">input_dtype</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span>
|
||
<span class="n">fp32_input</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">)</span>
|
||
<span class="n">varx</span> <span class="o">=</span> <span class="nb">pow</span><span class="p">(</span><span class="n">fp32_input</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)</span>
|
||
|
||
<span class="n">varx</span> <span class="o">=</span> <span class="n">varx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="n">denom</span> <span class="o">=</span> <span class="n">varx</span> <span class="o">+</span> <span class="n">eps</span>
|
||
<span class="n">denom</span> <span class="o">=</span> <span class="n">denom</span><span class="o">.</span><span class="n">sqrt</span><span class="p">()</span>
|
||
<span class="n">fp32_y</span> <span class="o">=</span> <span class="n">fp32_input</span> <span class="o">/</span> <span class="n">denom</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">fp32_y</span><span class="p">,</span> <span class="n">input_dtype</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">with</span> <span class="n">precision</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">):</span>
|
||
<span class="n">varx</span> <span class="o">=</span> <span class="nb">pow</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)</span>
|
||
<span class="n">varx</span> <span class="o">=</span> <span class="n">varx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="n">denom</span> <span class="o">=</span> <span class="n">varx</span> <span class="o">+</span> <span class="n">eps</span>
|
||
<span class="n">denom</span> <span class="o">=</span> <span class="n">denom</span><span class="o">.</span><span class="n">sqrt</span><span class="p">()</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">/</span> <span class="n">denom</span>
|
||
|
||
<span class="k">if</span> <span class="n">weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">*</span> <span class="n">weight</span>
|
||
|
||
<span class="k">return</span> <span class="n">y</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="c1"># TODO remove the plugin version if rmsnorm operation can be offloaded</span>
|
||
<span class="c1"># to Myelin.</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Rmsnorm'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">eps</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"eps"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">eps</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">rmsnorm_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">eps</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">rmsnorm_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"rmsnorm"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
|
||
<span class="n">normalized_shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">normalized_shape</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
|
||
<span class="n">normalized_shape</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="n">normalized_shape</span>
|
||
<span class="k">if</span> <span class="n">weight</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">weight</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">str_dtype_to_np</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)))</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">rmsnorm_plug</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="repeat_interleave">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.repeat_interleave">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">repeat_interleave</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">repeats</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Repeats elements of a tensor along an axis.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> repeats : int</span>
|
||
<span class="sd"> The number of repetitions along axis specified.</span>
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension along which repetitions are performed.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor with the same shape as input except for repeated elements along specified dim.</span>
|
||
|
||
<span class="sd"> TODO: Allow repeats to be a list of integers and dim to be unspecified.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">expanded_tensor</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">tile_output_size</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">repeats</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">expanded_tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">expanded_tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
<span class="p">])</span>
|
||
<span class="n">tile</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">expanded_tensor</span><span class="p">,</span> <span class="n">tile_output_size</span><span class="p">)</span>
|
||
<span class="n">tile_reshape_size</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())]</span>
|
||
<span class="n">tile_reshape_size</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">tile_reshape_size</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">*</span> <span class="n">repeats</span>
|
||
<span class="n">tensor</span> <span class="o">=</span> <span class="n">tile</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">tile_reshape_size</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="n">tensor</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="generate_alibi_slopes">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.generate_alibi_slopes">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">generate_alibi_slopes</span><span class="p">(</span><span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
|
||
<span class="n">tp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">tp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">alibi_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Compute the ALiBi slopes as described in https://arxiv.org/abs/2211.05100.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> num_heads : int</span>
|
||
<span class="sd"> The number of heads.</span>
|
||
<span class="sd"> dtype : trt.DataType</span>
|
||
<span class="sd"> The data type of the returned slopes</span>
|
||
<span class="sd"> tp_size : int</span>
|
||
<span class="sd"> The tensor parallelism size</span>
|
||
<span class="sd"> tp_rank : int</span>
|
||
<span class="sd"> The tensor parallelism rank</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A constant tensor that contains the ALiBi slopes.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">start_head_id</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">end_head_id</span> <span class="o">=</span> <span class="n">num_heads</span>
|
||
|
||
<span class="k">if</span> <span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">rank_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="o">//</span> <span class="n">tp_size</span>
|
||
<span class="n">start_head_id</span> <span class="o">=</span> <span class="n">rank_heads</span> <span class="o">*</span> <span class="n">tp_rank</span>
|
||
<span class="n">end_head_id</span> <span class="o">=</span> <span class="n">start_head_id</span> <span class="o">+</span> <span class="n">rank_heads</span>
|
||
|
||
<span class="n">closest_power_of_2</span> <span class="o">=</span> <span class="mi">2</span><span class="o">**</span><span class="n">np</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">num_heads</span><span class="p">))</span>
|
||
<span class="c1"># FT's implementation</span>
|
||
<span class="c1"># https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/gen_relative_pos_bias.cu#L248</span>
|
||
<span class="n">slopes_ft</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">h_id</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">start_head_id</span><span class="p">,</span> <span class="n">end_head_id</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">h_id</span> <span class="o"><</span> <span class="n">closest_power_of_2</span><span class="p">:</span>
|
||
<span class="n">slopes_ft</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">power</span><span class="p">(</span><span class="mi">2</span><span class="o">**</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="mi">2</span><span class="o">**-</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">closest_power_of_2</span><span class="p">)</span> <span class="o">-</span> <span class="mi">3</span><span class="p">))),</span>
|
||
<span class="n">h_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">slopes_ft</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">power</span><span class="p">(</span><span class="mi">2</span><span class="o">**</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="mi">2</span><span class="o">**-</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">closest_power_of_2</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span> <span class="o">-</span> <span class="mi">3</span><span class="p">))),</span>
|
||
<span class="p">(</span><span class="n">h_id</span> <span class="o">-</span> <span class="n">closest_power_of_2</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
|
||
<span class="n">slopes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">slopes_ft</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||
|
||
<span class="n">slopes</span> <span class="o">=</span> <span class="n">alibi_scale</span> <span class="o">*</span> <span class="n">slopes</span>
|
||
<span class="c1"># Note that for bfloat16, we cannot case numpy tensor from float32 to bfloat16</span>
|
||
<span class="c1"># because numpy does not support bfloat16. Even if we use custom type to define</span>
|
||
<span class="c1"># the np_bfloat16, the "astype" here would be undefined.</span>
|
||
<span class="c1"># So, we must use torch to cast tensor from float32 to bfloat16, and then use torch_to_numpy</span>
|
||
<span class="c1"># to cast the tensor back.</span>
|
||
<span class="n">slopes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">slopes</span><span class="p">)</span>
|
||
<span class="n">slopes</span> <span class="o">=</span> <span class="n">slopes</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">dtype</span><span class="p">))</span>
|
||
<span class="n">slopes</span> <span class="o">=</span> <span class="n">torch_to_numpy</span><span class="p">(</span><span class="n">slopes</span><span class="p">)</span>
|
||
<span class="n">slopes</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">slopes</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="n">end_head_id</span> <span class="o">-</span> <span class="n">start_head_id</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="n">slopes</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="generate_alibi_biases">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.generate_alibi_biases">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">generate_alibi_biases</span><span class="p">(</span><span class="n">slopes</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">key_length</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Compute the ALiBi biases as described in https://arxiv.org/abs/2211.05100.</span>
|
||
|
||
<span class="sd"> The ALiBi biases are added to the result of the Q*K^T product in the</span>
|
||
<span class="sd"> multihead-attention block.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> slopes : Tensor</span>
|
||
<span class="sd"> The slopes.</span>
|
||
|
||
<span class="sd"> key_length : Tensor</span>
|
||
<span class="sd"> The size of the K vector per head.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A constant tensor that contains the ALiBi biases.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="c1"># We don't need to care about the batch size or query length since we can just broadcast</span>
|
||
<span class="c1"># across the batch and query dimensions</span>
|
||
|
||
<span class="n">trt_0</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
|
||
<span class="n">arange_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">key_length</span><span class="p">])</span>
|
||
|
||
<span class="n">arange_tensor</span> <span class="o">=</span> <span class="n">arange</span><span class="p">(</span><span class="n">trt_0</span><span class="p">,</span> <span class="n">key_length</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">arange_shape</span><span class="p">)</span>
|
||
<span class="n">arange_tensor</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">arange_tensor</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">slopes</span> <span class="o">*</span> <span class="n">arange_tensor</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="expand_mask">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_mask">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">expand_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_len</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Expand an attention mask.</span>
|
||
|
||
<span class="sd"> That function adds the sequence of operations to expand from a tensor of</span>
|
||
<span class="sd"> shape '[batch_size, src_seq_len]' to a tensor of shape</span>
|
||
<span class="sd"> '[batch_size, 1, tgt_seq_len, src_seq_len]'. It can be used to create the</span>
|
||
<span class="sd"> mask applied to the Q*K^T product before the softmax operation in the</span>
|
||
<span class="sd"> multihead-attention block.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> mask : Tensor</span>
|
||
<span class="sd"> The input mask</span>
|
||
|
||
<span class="sd"> tgt_len : Optional[Tensor]</span>
|
||
<span class="sd"> The dimension of the 3rd dimension in the output tensor. If None,</span>
|
||
<span class="sd"> the 2nd dimension of the input is used.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor created by that sequence of operations.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">bsz</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">src_len</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">tgt_len</span> <span class="o">=</span> <span class="n">tgt_len</span> <span class="k">if</span> <span class="n">tgt_len</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">src_len</span>
|
||
|
||
<span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">([</span><span class="n">bsz</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">src_len</span><span class="p">]))</span>
|
||
|
||
<span class="n">mask</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">concat</span><span class="p">([</span><span class="n">bsz</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">tgt_len</span><span class="p">,</span> <span class="n">src_len</span><span class="p">]))</span>
|
||
<span class="n">mask</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">'-inf'</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">mask</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="s1">'float32'</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="n">mask</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="gather_last_token_logits">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gather_last_token_logits">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">gather_last_token_logits</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">last_token_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">remove_input_padding</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Extract the logits that correspond to the last token from the hidden states.</span>
|
||
|
||
<span class="sd"> That function adds the operations to extract the logits of the last tokens</span>
|
||
<span class="sd"> in a batch of sequences.</span>
|
||
|
||
<span class="sd"> Depending on whether 'remove_input_padding' is 'True' or 'False', that</span>
|
||
<span class="sd"> function assumes inputs of different shapes.</span>
|
||
|
||
<span class="sd"> When 'remove_input_padding' is 'True', the 'hidden_states' tensor is</span>
|
||
<span class="sd"> assumed to be packed. It has a shape '[num_tokens, hidden_dim]' where</span>
|
||
<span class="sd"> 'num_tokens' is the sum of the lengths of the sequences in the batch and</span>
|
||
<span class="sd"> 'hidden_dim' is the hidden dimension. The 'last_tokens_ids' is a 1D tensor</span>
|
||
<span class="sd"> that encodes the inclusive prefix-sums of the lengths of the sequences in</span>
|
||
<span class="sd"> the batch.</span>
|
||
|
||
<span class="sd"> When 'remove_input_padding' is 'False', the 'hidden_states' tensor is</span>
|
||
<span class="sd"> assumed to be padded. It has a shape '[batch_size, max_seqlen, hidden_dim]'</span>
|
||
<span class="sd"> where 'max_seqlen' is the length of the longest sequence in the batch and</span>
|
||
<span class="sd"> 'hidden_dim' is the hidden dimension. The 'last_token_ids' is a 1D tensor</span>
|
||
<span class="sd"> that encodes the length of each sequence in the batch.</span>
|
||
|
||
<span class="sd"> In both cases, that function produces a tensor of shape '[batch_size,</span>
|
||
<span class="sd"> hidden_size]' where the row at index 'i' corresponds to the logits of the</span>
|
||
<span class="sd"> last token from the 'i'-th sequence.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> hidden_states : Tensor</span>
|
||
<span class="sd"> The hidden states</span>
|
||
|
||
<span class="sd"> last_token_ids : Tensor</span>
|
||
<span class="sd"> The inclusive prefix-sum of the lengths or the lengths of the</span>
|
||
<span class="sd"> sequences in the batch.</span>
|
||
|
||
<span class="sd"> remove_input_padding : bool</span>
|
||
<span class="sd"> Indicate if the hidden_states are packed ('True') or padded</span>
|
||
<span class="sd"> ('False').</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor created by that sequence of operations.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">last_token_ids</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">hidden_states</span>
|
||
|
||
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">index_select</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># [1, seq_len, hidden]</span>
|
||
|
||
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)]))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="c1"># only calculate logits for the last token</span>
|
||
<span class="c1"># [batch_size, seqlen, hidden_size] -> [batch_size, hidden_size]</span>
|
||
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">]))</span>
|
||
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span>
|
||
<span class="n">last_token_ids</span><span class="p">,</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)]))</span>
|
||
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span>
|
||
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span>
|
||
<span class="n">hidden_states</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">indices</span><span class="o">=</span><span class="n">last_token_ids</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)]))</span>
|
||
<span class="k">return</span> <span class="n">hidden_states</span></div>
|
||
|
||
|
||
|
||
<span class="n">ACT2FN</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="s1">'relu'</span><span class="p">:</span> <span class="n">relu</span><span class="p">,</span>
|
||
<span class="s1">'tanh'</span><span class="p">:</span> <span class="n">tanh</span><span class="p">,</span>
|
||
<span class="s1">'gelu'</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
|
||
<span class="s1">'gelu_new'</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
|
||
<span class="s1">'gelu_fast'</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
|
||
<span class="s1">'geglu'</span><span class="p">:</span> <span class="n">geglu</span><span class="p">,</span>
|
||
<span class="s1">'silu'</span><span class="p">:</span> <span class="n">silu</span><span class="p">,</span>
|
||
<span class="s1">'softplus'</span><span class="p">:</span> <span class="n">softplus</span><span class="p">,</span>
|
||
<span class="s1">'squared-relu'</span><span class="p">:</span> <span class="n">squared_relu</span><span class="p">,</span>
|
||
<span class="s1">'swiglu'</span><span class="p">:</span> <span class="n">swiglu</span><span class="p">,</span>
|
||
<span class="s1">'fast-swiglu'</span><span class="p">:</span> <span class="n">swiglu</span><span class="p">,</span>
|
||
<span class="p">}</span>
|
||
|
||
<span class="n">GATED_ACT_2_ACT</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="s1">'swiglu'</span><span class="p">:</span> <span class="s1">'silu'</span><span class="p">,</span>
|
||
<span class="s1">'fast-swiglu'</span><span class="p">:</span> <span class="s1">'silu'</span><span class="p">,</span>
|
||
<span class="s1">'geglu'</span><span class="p">:</span> <span class="s1">'gelu'</span><span class="p">,</span>
|
||
<span class="p">}</span>
|
||
|
||
|
||
<div class="viewcode-block" id="is_gated_activation">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.is_gated_activation">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">is_gated_activation</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Is a given activation function gated?</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> activation : str</span>
|
||
<span class="sd"> The name of the activation function.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> True if the function is gated, False otherwise.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">activation</span> <span class="ow">in</span> <span class="n">ACT2FN</span>
|
||
<span class="k">return</span> <span class="n">activation</span> <span class="ow">in</span> <span class="n">GATED_ACT_2_ACT</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="non_gated_version">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.non_gated_version">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">non_gated_version</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Given an activation function, get the non-gated version.</span>
|
||
|
||
<span class="sd"> If the activation function is non-gated, it returns the same activation</span>
|
||
<span class="sd"> function name.</span>
|
||
|
||
<span class="sd"> For example, that function returns 'silu' for 'swiglu' and 'relu' for</span>
|
||
<span class="sd"> 'relu'.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> activation : str</span>
|
||
<span class="sd"> The name of the activation function.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The name of the non-gated activation function.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">is_gated_activation</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">GATED_ACT_2_ACT</span><span class="p">[</span><span class="n">activation</span><span class="p">]</span>
|
||
<span class="k">return</span> <span class="n">activation</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="lora_plugin">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.lora_plugin">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">lora_plugin</span><span class="p">(</span>
|
||
<span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">in_hidden_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">out_hidden_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">transa</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">transb</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for pad-free input mode</span>
|
||
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">max_low_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">lora_ranks</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">lora_weights_pointers</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> lora_ids : cpu Tensor = None</span>
|
||
<span class="sd"> A tensor that contains the lora ids of different inputs.</span>
|
||
|
||
<span class="sd"> in_hidden_size/out_hidden_size : int</span>
|
||
<span class="sd"> the lora computation workflow is</span>
|
||
<span class="sd"> [M, in_hidden_size] -> [M, low_rank] -> [M, out_hidden_size]</span>
|
||
|
||
<span class="sd"> host_request_types : Tensor = None</span>
|
||
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
|
||
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
|
||
<span class="sd"> in docs/gpt_attention.md,</span>
|
||
|
||
<span class="sd"> transa : bool</span>
|
||
<span class="sd"> Is the first input transposed? Set to 'True' if you want the first</span>
|
||
<span class="sd"> input to be transposed, 'False' otherwise.</span>
|
||
|
||
<span class="sd"> transb : bool</span>
|
||
<span class="sd"> Is the second input transposed? Set to 'True' if you want the</span>
|
||
<span class="sd"> second input to be transposed, 'False' otherwise.</span>
|
||
|
||
<span class="sd"> host_context_lengths: cpu Tensor = None</span>
|
||
<span class="sd"> A host tensor that contains the lengths of the different inputs,</span>
|
||
|
||
<span class="sd"> max_context_length : int</span>
|
||
<span class="sd"> Maximum length during context phase, used to determine the workspace size.</span>
|
||
|
||
<span class="sd"> max_low_rank : int</span>
|
||
<span class="sd"> Maximum low_rank, used to determine the workspace size.</span>
|
||
|
||
<span class="sd"> lora_ranks : cpu Tensor with shape [batch_size]</span>
|
||
<span class="sd"> The low_rank of each request</span>
|
||
|
||
<span class="sd"> lora_weights_pointers : cpu int64 Tensor with shape [batch_size, 2]</span>
|
||
<span class="sd"> The weights pointers of each request. Consist of in_pointer and out_pointer.</span>
|
||
|
||
<span class="sd"> Return:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">host_context_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">(</span>
|
||
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
|
||
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_creator_list</span>
|
||
<span class="n">in_hidden_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"in_hidden_size"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">in_hidden_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">out_hidden_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"out_hidden_size"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">out_hidden_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">transa</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">transa</span> <span class="k">else</span> <span class="mi">0</span>
|
||
<span class="n">transa</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"transa"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transa</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">transb</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">transb</span> <span class="k">else</span> <span class="mi">0</span>
|
||
<span class="n">transb</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"transb"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transb</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Lora'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lora_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"remove_input_padding"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">max_context_length_filed</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"max_context_length"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">max_low_rank_filed</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"max_low_rank"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_low_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
|
||
<span class="n">in_hidden_size</span><span class="p">,</span> <span class="n">out_hidden_size</span><span class="p">,</span> <span class="n">transa</span><span class="p">,</span> <span class="n">transb</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span>
|
||
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">max_context_length_filed</span><span class="p">,</span> <span class="n">max_low_rank_filed</span>
|
||
<span class="p">])</span>
|
||
<span class="n">lora_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"lora"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="p">,</span> <span class="n">host_request_types</span><span class="p">,</span> <span class="n">lora_ranks</span><span class="p">,</span> <span class="n">lora_weights_pointers</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">lora_plug</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
</pre></div>
|
||
|
||
</div>
|
||
</div>
|
||
<footer>
|
||
|
||
<hr/>
|
||
|
||
<div role="contentinfo">
|
||
<p>© Copyright 2023, NVIDIA.</p>
|
||
</div>
|
||
|
||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||
|
||
|
||
</footer>
|
||
</div>
|
||
</div>
|
||
</section>
|
||
</div>
|
||
<!-- On DOM ready (jQuery(callback) shorthand), switch on the sidebar
     navigation behaviour defined by the Read the Docs theme script, which
     the page head loads from _static/js/theme.js.
     NOTE(review): the `true` argument presumably enables sticky navigation;
     confirm against theme.js before relying on that.
     This page appears to be Sphinx build output ("Built with Sphinx" in the
     footer), so lasting changes belong in the theme template, not here. -->
<script>
|
||
jQuery(function () {
|
||
SphinxRtdTheme.Navigation.enable(true);
|
||
});
|
||
</script>
|
||
|
||
</body>
|
||
</html> |