mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
7570 lines
939 KiB
HTML
7570 lines
939 KiB
HTML
|
||
|
||
<!DOCTYPE html>
|
||
<html class="writer-html5" lang="en" data-content_root="../../">
|
||
<head>
|
||
<meta charset="utf-8" />
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||
<title>tensorrt_llm.functional — tensorrt_llm documentation</title>
|
||
<link rel="stylesheet" type="text/css" href="../../_static/pygments.css?v=80d5e7a1" />
|
||
<link rel="stylesheet" type="text/css" href="../../_static/css/theme.css?v=e59714d7" />
|
||
<link rel="stylesheet" type="text/css" href="../../_static/copybutton.css?v=76b2166b" />
|
||
|
||
|
||
<script src="../../_static/jquery.js?v=5d32c60e"></script>
|
||
<script src="../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
|
||
<script src="../../_static/documentation_options.js?v=5929fcd5"></script>
|
||
<script src="../../_static/doctools.js?v=9a2dae69"></script>
|
||
<script src="../../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||
<script src="../../_static/clipboard.min.js?v=a7894cd8"></script>
|
||
<script src="../../_static/copybutton.js?v=65e89d2a"></script>
|
||
<script src="../../_static/js/theme.js"></script>
|
||
<link rel="index" title="Index" href="../../genindex.html" />
|
||
<link rel="search" title="Search" href="../../search.html" />
|
||
</head>
|
||
|
||
<body class="wy-body-for-nav">
|
||
<div class="wy-grid-for-nav">
|
||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||
<div class="wy-side-scroll">
|
||
<div class="wy-side-nav-search" >
|
||
|
||
|
||
|
||
<a href="../../index.html" class="icon icon-home">
|
||
tensorrt_llm
|
||
</a>
|
||
<div role="search">
|
||
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
|
||
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
|
||
<input type="hidden" name="check_keywords" value="yes" />
|
||
<input type="hidden" name="area" value="default" />
|
||
</form>
|
||
</div>
|
||
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
|
||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../overview.html">Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../quick-start-guide.html">Quick Start Guide</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../key-features.html">Key Features</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../release-notes.html">Release Notes</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Installation</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../installation/linux.html">Installing on Linux</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../installation/build-from-source-linux.html">Building from Source Code on Linux</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../installation/windows.html">Installing on Windows</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../installation/build-from-source-windows.html">Building from Source Code on Windows</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../installation/grace-hopper.html">Installing on Grace Hopper</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">LLM API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../llm-api/index.html">API Introduction</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../llm-api/reference.html">API Reference</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">LLM API Examples</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../llm-api-examples/index.html">LLM Examples Introduction</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../llm-api-examples/customization.html">Common Customizations</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../llm-api-examples/llm_api_examples.html">Examples</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Model Definition API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.layers.html">Layers</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.functional.html">Functionals</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.models.html">Models</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../_cpp_gen/executor.html">Executor</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../_cpp_gen/runtime.html">Runtime</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Command-Line Reference</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../commands/trtllm-build.html">trtllm-build</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../commands/trtllm-serve.html">trtllm-serve</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Architecture</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../architecture/overview.html">TensorRT-LLM Architecture</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../architecture/core-concepts.html">Model Definition</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../architecture/core-concepts.html#compilation">Compilation</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../architecture/core-concepts.html#runtime">Runtime</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../architecture/core-concepts.html#multi-gpu-and-multi-node-support">Multi-GPU and Multi-Node Support</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../architecture/checkpoint.html">TensorRT-LLM Checkpoint</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../architecture/workflow.html">TensorRT-LLM Build Workflow</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../architecture/add-model.html">Adding a Model</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Advanced</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../advanced/gpt-attention.html">Multi-Head, Multi-Query, and Group-Query Attention</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../advanced/gpt-runtime.html">C++ GPT Runtime</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../advanced/executor.html">Executor API</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../advanced/graph-rewriting.html">Graph Rewriting Module</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../advanced/inference-request.html">Inference Request</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../advanced/inference-request.html#responses">Responses</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../advanced/lora.html">Run gpt-2b + LoRA using GptManager / cpp runtime</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../performance/perf-overview.html">Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../performance/perf-benchmarking.html">Benchmarking</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../performance/perf-best-practices.html">Best Practices</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../performance/perf-analysis.html">Performance Analysis</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Reference</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../reference/troubleshooting.html">Troubleshooting</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../reference/support-matrix.html">Support Matrix</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../reference/precision.html">Numerical Precision</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../reference/memory.html">Memory Usage of TensorRT-LLM</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Blogs</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../blogs/Falcon180B-H200.html">Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../blogs/quantization-in-TRT-LLM.html">Speed up inference with SOTA quantization techniques in TRT-LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../blogs/XQA-kernel.html">New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget</a></li>
|
||
</ul>
|
||
|
||
</div>
|
||
</div>
|
||
</nav>
|
||
|
||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
|
||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||
<a href="../../index.html">tensorrt_llm</a>
|
||
</nav>
|
||
|
||
<div class="wy-nav-content">
|
||
<div class="rst-content">
|
||
<div role="navigation" aria-label="Page navigation">
|
||
<ul class="wy-breadcrumbs">
|
||
<li><a href="../../index.html" class="icon icon-home" aria-label="Home"></a></li>
|
||
<li class="breadcrumb-item"><a href="../index.html">Module code</a></li>
|
||
<li class="breadcrumb-item active">tensorrt_llm.functional</li>
|
||
<li class="wy-breadcrumbs-aside">
|
||
</li>
|
||
</ul>
|
||
<hr/>
|
||
</div>
|
||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||
<div itemprop="articleBody">
|
||
|
||
<h1>Source code for tensorrt_llm.functional</h1><div class="highlight"><pre>
|
||
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.</span>
|
||
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
|
||
<span class="c1">#</span>
|
||
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
||
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
||
<span class="c1"># You may obtain a copy of the License at</span>
|
||
<span class="c1">#</span>
|
||
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
||
<span class="c1">#</span>
|
||
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
||
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
||
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
||
<span class="c1"># See the License for the specific language governing permissions and</span>
|
||
<span class="c1"># limitations under the License.</span>
|
||
<span class="kn">import</span> <span class="nn">math</span>
|
||
<span class="kn">import</span> <span class="nn">weakref</span>
|
||
<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">OrderedDict</span>
|
||
<span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">IntEnum</span><span class="p">,</span> <span class="n">IntFlag</span><span class="p">,</span> <span class="n">auto</span>
|
||
<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
|
||
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
|
||
|
||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||
|
||
<span class="c1"># isort: off</span>
|
||
<span class="kn">import</span> <span class="nn">tensorrt</span> <span class="k">as</span> <span class="nn">trt</span>
|
||
<span class="c1"># isort: on</span>
|
||
|
||
<span class="kn">from</span> <span class="nn">.</span> <span class="kn">import</span> <span class="n">graph_rewriting</span> <span class="k">as</span> <span class="n">gw</span>
|
||
<span class="kn">from</span> <span class="nn">._common</span> <span class="kn">import</span> <span class="n">default_net</span><span class="p">,</span> <span class="n">default_trtnet</span><span class="p">,</span> <span class="n">precision</span>
|
||
<span class="kn">from</span> <span class="nn">._utils</span> <span class="kn">import</span> <span class="p">(</span><span class="n">QuantModeWrapper</span><span class="p">,</span> <span class="n">bf16_array</span><span class="p">,</span> <span class="n">bool_array</span><span class="p">,</span>
|
||
<span class="n">dim_resolve_negative</span><span class="p">,</span> <span class="n">dim_to_trt_axes</span><span class="p">,</span> <span class="n">dims_array</span><span class="p">,</span>
|
||
<span class="n">fp16_array</span><span class="p">,</span> <span class="n">fp32_array</span><span class="p">,</span> <span class="n">int32_array</span><span class="p">,</span> <span class="n">int64_array</span><span class="p">,</span>
|
||
<span class="n">np_dtype_to_trt</span><span class="p">,</span> <span class="n">str_dtype_to_trt</span><span class="p">,</span> <span class="n">trt_dtype_to_np</span><span class="p">,</span>
|
||
<span class="n">trt_dtype_to_str</span><span class="p">)</span>
|
||
<span class="kn">from</span> <span class="nn">.network</span> <span class="kn">import</span> <span class="n">PluginInfo</span><span class="p">,</span> <span class="n">set_np_weight</span><span class="p">,</span> <span class="n">set_plugin_info</span>
|
||
<span class="kn">from</span> <span class="nn">.plugin</span> <span class="kn">import</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">,</span> <span class="n">current_all_reduce_helper</span>
|
||
<span class="kn">from</span> <span class="nn">.quantization</span> <span class="kn">import</span> <span class="n">QuantMode</span>
|
||
|
||
|
||
<div class="viewcode-block" id="DimRange">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.DimRange">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">DimRange</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> One DimRange object stores the ranges of all the dimensions of one tensor in one optimization profile.</span>
|
||
<span class="sd"> For example, tensor has 2 dimensions. Then the data members are:</span>
|
||
<span class="sd"> self.min = [dim 0 min, dim 1 min]</span>
|
||
<span class="sd"> self.opt = [dim 0 opt, dim 1 opt]</span>
|
||
<span class="sd"> self.max = [dim 0 max, dim 1 max]</span>
|
||
<span class="sd"> For static dimension, it has min==opt==max, thus the shape param in the ctor can be an integer</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]],</span>
|
||
<span class="n">names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> shape: a list with length N, each element is an integer or a 3-elements tuple/list of int,</span>
|
||
<span class="sd"> where N is the number of dimensions for a tensor.</span>
|
||
<span class="sd"> When one element is an integer, it means that dimension is static.</span>
|
||
<span class="sd"> Otherwise, when one element is a tuple/list, it means the dimension is dynamic.</span>
|
||
<span class="sd"> The 3 elements in one tuple/list is ordered by (min, opt, max), and this function asserts</span>
|
||
<span class="sd"> 0 <= min <= opt <= max.</span>
|
||
|
||
<span class="sd"> Example, for a 3 rank tensor, with 1st dimension being static and has value 16, and second dimension being dynamic with</span>
|
||
<span class="sd"> min/opt/max values being 1/8/32, and 3rd dimension being static and has value 8.</span>
|
||
<span class="sd"> The shape parameter could be:</span>
|
||
<span class="sd"> [16, (1, 8, 32), 8]</span>
|
||
<span class="sd"> It has same semantics of</span>
|
||
<span class="sd"> [(16, 16, 16), (1, 8, 32), (8, 8, 8)]</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">min</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">opt</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span> <span class="o">=</span> <span class="n">names</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">names</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span>
|
||
<span class="n">shape</span>
|
||
<span class="p">),</span> <span class="s2">"Expecting shape list and name list must have same length, got {shape=}, {name=}"</span>
|
||
<span class="k">for</span> <span class="n">dim</span> <span class="ow">in</span> <span class="n">shape</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span> <span class="ow">and</span> <span class="mi">0</span> <span class="o"><=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o"><=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o"><=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> \
|
||
<span class="s2">"Each dimension must specify a 3-elements tuple or list in the order of (min,opt,max), got {dim=}"</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">AttributeError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s1">'Dimension should be [min, opt, max] (dynamic shape) or int (specific value). Got </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span><span class="si">}</span><span class="s1">'</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__eq__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">__value</span><span class="p">:</span> <span class="nb">object</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">__value</span><span class="p">,</span> <span class="n">DimRange</span><span class="p">)</span> <span class="ow">and</span> \
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">dimension_names</span> <span class="ow">and</span> \
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">min</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">min</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">opt</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">opt</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">max</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">max</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__str__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span><span class="si">=}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="si">=}</span><span class="s2">, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="si">=}</span><span class="s2">, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="si">=}</span><span class="s2">)"</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__hash__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="nb">hash</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="p">))</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">Tensor</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The class to represent dense tensors.</span>
|
||
|
||
<span class="sd"> A dense tensor is named, has a shape and contains typed elements. Each</span>
|
||
<span class="sd"> dimension of a tensor can either be static or dynamic. Static dimensions</span>
|
||
<span class="sd"> are known at engine compilation by TensorRT. Dynamic dimensions can take</span>
|
||
<span class="sd"> values determined at runtime. The tensor can be located on the host (CPU)</span>
|
||
<span class="sd"> or the device (GPU).</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
||
<span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">dim_range</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">is_network_input</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">location</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">TensorLocation</span><span class="o">.</span><span class="n">DEVICE</span><span class="p">,</span>
|
||
<span class="n">network</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">trt_tensor</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> name : str</span>
|
||
<span class="sd"> The name of the tensor.</span>
|
||
|
||
<span class="sd"> dtype : tensorrt.DataType</span>
|
||
<span class="sd"> The type of the elements of the tensor. See the TensorRT</span>
|
||
<span class="sd"> documentation for list of supported data types.</span>
|
||
|
||
<span class="sd"> shape : tensorrt.Dims</span>
|
||
<span class="sd"> The dimensions of the tensor. In TensorRT-LLM, tensors can have</span>
|
||
<span class="sd"> static or dynamic dimensions (it is possible to mix static and</span>
|
||
<span class="sd"> dynamic dimensions). A static dimension is known when the</span>
|
||
<span class="sd"> TensorRT engine is built. A dynamic dimension can be set when</span>
|
||
<span class="sd"> the engine is executed. Use -1 for dynamic dimensions.</span>
|
||
|
||
<span class="sd"> dim_range : OrderedDict</span>
|
||
<span class="sd"> An ordered dictionary (the positions of the elements matter)</span>
|
||
<span class="sd"> that associates a name and a range of values to the dimensions.</span>
|
||
<span class="sd"> For a static dimension, the range must be limited to a single</span>
|
||
<span class="sd"> value. For a dynamic dimension, the range is defined by three</span>
|
||
<span class="sd"> values [min, opt, max] where min and max are, respectively, the</span>
|
||
<span class="sd"> smallest and largest possible values of that dimension. The</span>
|
||
<span class="sd"> opt value is used by TensorRT to optimize the engine for the</span>
|
||
<span class="sd"> most common case.</span>
|
||
|
||
<span class="sd"> Assume there is N optimization profiles, each item dim_range dict is ordered by:</span>
|
||
<span class="sd"> (dynamic dimension name : [profile 0 (min, opt, max), profile 1 (min, opt, max), ... profile N(min, opt, max)])</span>
|
||
<span class="sd"> or it's following when the dimension is static (can think as min==opt==max):</span>
|
||
<span class="sd"> (static dimension name : [profile 0 value, profile 1 value, ... profile N value])</span>
|
||
<span class="sd"> For static dimension the profile 0-N value must be same, (TODO: can it be simplified to be only 1 value?)</span>
|
||
<span class="sd"> And number of keys is equal to number of optimization profiles.</span>
|
||
|
||
<span class="sd"> is_network_input : bool</span>
|
||
<span class="sd"> A boolean indicating if that tensor is an input of the network.</span>
|
||
<span class="sd"> Inputs must be provided by the user to run the engine.</span>
|
||
|
||
<span class="sd"> location : tensorrt.TensorLocation</span>
|
||
<span class="sd"> A flag to indicate where the tensor will be located. It can be</span>
|
||
<span class="sd"> on the host (CPU) or the device (GPU).</span>
|
||
|
||
<span class="sd"> network: Network</span>
|
||
<span class="sd"> A parent Network instance, that helps to fine the users of this tensor.</span>
|
||
|
||
<span class="sd"> trt_tensor: trt.ITensor</span>
|
||
<span class="sd"> Construct with the ITensor instance directly, and no shape profiles are required.</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="c1"># Layout of self.profiles</span>
|
||
<span class="c1"># Opt profile 0: dim 0 (min, opt, max), dim 1 (min, opt, max) ... dim M</span>
|
||
<span class="c1"># Opt profile 1: dim 0 (min, opt, max), dim 1 (min, opt, max) ... dim M</span>
|
||
<span class="c1"># ...</span>
|
||
<span class="c1"># Opt profile N: dim 0 ... dim M</span>
|
||
|
||
<span class="c1"># So from the dim_range arg to self.profiles conversion, there is a layout transpose</span>
|
||
<span class="c1"># dim_range arg is: {M dimensions x N profiles}, while self.profiles layout is {N profiles x M dimensions}</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
|
||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">profiles</span> <span class="o">=</span> <span class="p">[]</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">is_tensor_wrapper</span> <span class="o">=</span> <span class="kc">False</span> <span class="c1"># specially for the graph rewriter</span>
|
||
|
||
<span class="c1"># work as a wrapper for a trt.ITensor, this is used specially in the graph rewriter</span>
|
||
<span class="k">if</span> <span class="n">trt_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">is_tensor_wrapper</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="k">assert</span> <span class="n">network</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="o">=</span> <span class="n">trt_tensor</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_network</span> <span class="o">=</span> <span class="n">weakref</span><span class="o">.</span><span class="n">ref</span><span class="p">(</span><span class="n">network</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="n">is_network_input</span><span class="p">,</span> <span class="s2">"is_network_input should be False when trt_tensor is not None"</span>
|
||
<span class="k">return</span>
|
||
|
||
<span class="c1"># be cautious here, the weakref is critical to avoid circular referencing between Network and Tensor</span>
|
||
<span class="c1"># using a strong reference will likely cause a significant peak memory increase, since Network objects</span>
|
||
<span class="c1"># hold the weights data.</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_network</span> <span class="o">=</span> <span class="n">weakref</span><span class="o">.</span><span class="n">ref</span><span class="p">(</span><span class="n">default_net</span><span class="p">())</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">is_network_input</span> <span class="o">=</span> <span class="n">is_network_input</span>
|
||
<span class="k">if</span> <span class="n">is_network_input</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">dim_range</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim_range</span><span class="p">,</span> <span class="n">OrderedDict</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span>
|
||
<span class="n">dim_range</span>
|
||
<span class="p">)</span> <span class="o">>=</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"Each input tensor shall have at least one dimension, tensor '</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">' found </span><span class="si">{</span><span class="n">dim_range</span><span class="si">=}</span><span class="s2">"</span>
|
||
|
||
<span class="n">found_profiles</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="nb">len</span><span class="p">(</span><span class="n">ranges</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">ranges</span> <span class="ow">in</span> <span class="n">dim_range</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
|
||
<span class="p">]</span>
|
||
<span class="k">assert</span> <span class="nb">all</span><span class="p">(</span>
|
||
<span class="p">[</span><span class="n">x</span> <span class="o">==</span> <span class="n">found_profiles</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">found_profiles</span><span class="p">]</span>
|
||
<span class="p">),</span> <span class="sa">f</span><span class="s2">"Expecting all the dimensions in the dim_range has same number of profiles, tensor '</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">' got </span><span class="si">{</span><span class="n">dim_range</span><span class="si">=}</span><span class="s2">"</span>
|
||
|
||
<span class="n">num_opt_profile</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">dim_range</span><span class="o">.</span><span class="n">items</span><span class="p">())[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">])</span>
|
||
<span class="k">assert</span> <span class="n">num_opt_profile</span> <span class="o">>=</span> <span class="mi">1</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_opt_profile</span><span class="p">):</span>
|
||
<span class="n">range_shape</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="n">dimension_names</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">dim</span><span class="p">,</span> <span class="n">ranges</span> <span class="ow">in</span> <span class="n">dim_range</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">ranges</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">))</span>
|
||
<span class="n">range_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">ranges</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
|
||
<span class="n">dimension_names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">profiles</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">DimRange</span><span class="p">(</span><span class="n">range_shape</span><span class="p">,</span> <span class="n">dimension_names</span><span class="p">))</span>
|
||
|
||
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_add_input</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">dim_range</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">location</span> <span class="o">=</span> <span class="n">location</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span> <span class="nf">network</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_network</span><span class="p">()</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span> <span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The name of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span>
|
||
|
||
<span class="nd">@name</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span> <span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Set the name of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">name</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span> <span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The type of the elements in the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span>
|
||
|
||
<span class="nd">@dtype</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span> <span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Set the type of the elements in the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The shape of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
|
||
|
||
<span class="nd">@shape</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Set the shape of the tensor. See __init__.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">shape</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span> <span class="nf">location</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The physical location of the tensor (on the host or the device).</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">location</span>
|
||
|
||
<span class="nd">@location</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span> <span class="nf">location</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">location</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Set the physical location of the tensor (on the host or the device). See __init__.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">location</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">location</span> <span class="o">=</span> <span class="n">location</span>
|
||
|
||
<div class="viewcode-block" id="Tensor.mark_output">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.mark_output">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">mark_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
||
<span class="n">name</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Mark a tensor as a network output.</span>
|
||
|
||
<span class="sd"> When a tensor is marked as an output, its content can be obtained after</span>
|
||
<span class="sd"> the execution of the TensorRT engine. The user is responsible for</span>
|
||
<span class="sd"> allocating buffers to store the output tensors when preparing the</span>
|
||
<span class="sd"> execution of the TensorRT engine.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">name</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
|
||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
|
||
<span class="k">assert</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span>
|
||
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_mark_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
<span class="k">def</span> <span class="fm">__add__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.add.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__radd__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.add.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">add</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__sub__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.sub.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">sub</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__rsub__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.sub.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">sub</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__mul__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.mul.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">mul</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__rmul__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.mul.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">mul</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__truediv__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.div.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__floordiv__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.floordiv.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">floordiv</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__mod__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.modulo.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">modulo</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__lt__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.lt.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">lt</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__gt__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.gt.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">gt</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__eq__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.eq.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_tensor_wrapper</span><span class="p">:</span>
|
||
<span class="c1"># for graph rewriter</span>
|
||
<span class="k">return</span> <span class="nb">hash</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">==</span> <span class="nb">hash</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="c1"># for creating the network</span>
|
||
<span class="k">return</span> <span class="n">eq</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__ge__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Maps to functional.gt or functional.eq.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">op_or</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="fm">__gt__</span><span class="p">(</span><span class="n">b</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="fm">__eq__</span><span class="p">(</span><span class="n">b</span><span class="p">))</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__le__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Maps to functional.lt or functional.eq.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">op_or</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="fm">__lt__</span><span class="p">(</span><span class="n">b</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="fm">__eq__</span><span class="p">(</span><span class="n">b</span><span class="p">))</span>
|
||
|
||
<div class="viewcode-block" id="Tensor.view">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.view">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">view</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.view.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.flatten">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.flatten">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">flatten</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">start_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">end_dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.flatten.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">flatten</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">start_dim</span><span class="p">,</span> <span class="n">end_dim</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.permute">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.permute">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">permute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.permute.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">permute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.transpose">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.transpose">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">transpose</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim0</span><span class="p">,</span> <span class="n">dim1</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.transpose.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">transpose</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim0</span><span class="p">,</span> <span class="n">dim1</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.mean">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.mean">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">mean</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.mean.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">mean</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.max">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.max">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">max</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.max.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">max</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.abs">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.abs">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">abs</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.abs.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">abs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.sqrt">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.sqrt">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.sqrt.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.log">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.log">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">log</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.log.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">log</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.cast">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.cast">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">cast</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.cast.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">cast</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.size">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.size">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">size</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Returns the shape of the tensor if the dim parameter is None.</span>
|
||
<span class="sd"> Otherwise, returns a size of the dimension indicated by dim. The</span>
|
||
<span class="sd"> behavior is undefined if dim is negative or exceeds the rank of the</span>
|
||
<span class="sd"> tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span>
|
||
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.rank">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.rank">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">rank</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Returns the rank (i.e. the number of dimensions) of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.ndim">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.ndim">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">ndim</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Returns the rank (i.e. the number of dimensions) of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.split">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.split">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">split</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">split_size_or_sections</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.split.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">split_size_or_sections</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.select">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.select">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">select</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.select.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">select</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">index</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.unbind">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.unbind">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">unbind</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.unbind.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">unbind</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.is_dynamic">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.is_dynamic">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">is_dynamic</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> If the argument 'dim' is None, that function returns a boolean that</span>
|
||
<span class="sd"> indicates if the tensor contains a dynamic dimension (True) or not</span>
|
||
<span class="sd"> (False). In that case, the first dimension is excluded (as it usually</span>
|
||
<span class="sd"> corresponds to the batch size). If the argument is an integer, that</span>
|
||
<span class="sd"> functions returns a boolean that indicates if the dimension 'dim' is</span>
|
||
<span class="sd"> dynamic (True) or not (False).</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span>
|
||
|
||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">s</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="kc">True</span>
|
||
|
||
<span class="k">return</span> <span class="kc">False</span></div>
|
||
|
||
|
||
<span class="c1"># graph writer related functions</span>
|
||
|
||
<div class="viewcode-block" id="Tensor.get_parent">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.get_parent">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">get_parent</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">''' Get the layer that produces this tensor. '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">get_tensor_parent</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.get_users">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.get_users">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">get_users</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">''' Get the layers that use this tensor as an input. '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">get_tensor_users</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.replace_all_uses_with">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.replace_all_uses_with">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">replace_all_uses_with</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_tensor</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Replace all uses of this tensor as an input to consumer layers</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">is_graph_altered</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="n">users</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_users</span><span class="p">()</span>
|
||
<span class="k">for</span> <span class="n">user</span> <span class="ow">in</span> <span class="n">users</span><span class="p">:</span>
|
||
<span class="n">inputs_changed</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">user</span><span class="o">.</span><span class="n">num_inputs</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">user</span><span class="o">.</span><span class="n">get_inputs</span><span class="p">(</span><span class="n">i</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="ow">is</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">:</span>
|
||
<span class="n">inputs_changed</span> <span class="o">+=</span> <span class="mi">1</span>
|
||
<span class="n">user</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">inputs_changed</span> <span class="o">>=</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"Tensor not found in layer inputs"</span>
|
||
|
||
<span class="c1"># update the FLayerMetadata as well</span>
|
||
<span class="n">flayer</span> <span class="o">=</span> <span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">user</span><span class="o">.</span><span class="n">name</span><span class="p">)</span>
|
||
<span class="n">flayer</span> <span class="ow">and</span> <span class="n">flayer</span><span class="o">.</span><span class="n">replace_input_with</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_tensor</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.is_trt_wrapper">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.is_trt_wrapper">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">is_trt_wrapper</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Check if there is a trt.ITensor member inside, which is required for</span>
|
||
<span class="sd"> graph rewriter. In order to differentiate usages, it may be necessary</span>
|
||
<span class="sd"> to have an inheritance hierarchy.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s1">'trt_tensor'</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="kc">True</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="kc">False</span></div>
|
||
|
||
|
||
<span class="k">def</span> <span class="fm">__hash__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_trt_wrapper</span><span class="p">():</span>
|
||
<span class="k">return</span> <span class="nb">id</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="nb">id</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="sa">f</span><span class="s2">"TensorRT-LLM Tensor: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="si">=}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="si">=}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="si">=}</span><span class="s2">"</span></div>
|
||
|
||
|
||
|
||
<span class="k">def</span> <span class="nf">_create_tensor</span><span class="p">(</span><span class="n">trt_tensor</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ITensor</span><span class="p">,</span> <span class="n">producer</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ILayer</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> A helper function to create a TensorRT-LLM Tensor object that encapsulates</span>
|
||
<span class="sd"> the connection between the TensorRT tensor (trt.ITensor) and the layer</span>
|
||
<span class="sd"> (trt.ILayer) that produces it.</span>
|
||
|
||
<span class="sd"> That function is expected to be used as:</span>
|
||
|
||
<span class="sd"> # Insert a new layer in the network using the TensorRT API:</span>
|
||
<span class="sd"> layer = default_trtnet().add_<some_layer>(...)</span>
|
||
<span class="sd"> # Extract the first output of that layer and connect it to the layer.</span>
|
||
<span class="sd"> return _create_tensor(layer.get_output(0), layer)</span>
|
||
|
||
<span class="sd"> That function also sets the precision of the layer/producer to the default</span>
|
||
<span class="sd"> precision of the network.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> trt_tensor : trt.ITensor</span>
|
||
<span class="sd"> The TensorRT tensor to connect to its producer (the layer).</span>
|
||
|
||
<span class="sd"> producer : trt.ILayer</span>
|
||
<span class="sd"> The producer.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The TensorRT-LLM tensor (functional.Tensor) that encapsulates the</span>
|
||
<span class="sd"> TensorRT tensor and the layer that produces it. The former is</span>
|
||
<span class="sd"> accessible through the attribute 'trt_tensor' and the latter using the</span>
|
||
<span class="sd"> attribute 'producer'.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">trt_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="c1"># Set the layer name since this is the only</span>
|
||
<span class="c1"># centralized location to pass the name from</span>
|
||
<span class="c1"># module space to the TRT IR</span>
|
||
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_set_layer_name</span><span class="p">(</span><span class="n">producer</span><span class="p">)</span>
|
||
|
||
<span class="k">assert</span> <span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="o">.</span><span class="fm">__len__</span><span class="p">(</span>
|
||
<span class="p">)</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"tensor </span><span class="si">{</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span><span class="si">}</span><span class="s2"> has an invalid shape"</span>
|
||
<span class="n">tensor</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
|
||
<span class="n">is_network_input</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
<span class="n">tensor</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="o">=</span> <span class="n">trt_tensor</span>
|
||
<span class="n">tensor</span><span class="o">.</span><span class="n">producer</span> <span class="o">=</span> <span class="n">producer</span>
|
||
|
||
<span class="c1"># tb.print_stack(limit=10) # FOR DEBUGGING: filter producer.name if needed</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">SHAPE</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">GATHER</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONCATENATION</span>
|
||
<span class="p">]:</span>
|
||
<span class="n">producer</span><span class="o">.</span><span class="n">precision</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">dtype</span>
|
||
<span class="k">assert</span> <span class="n">tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="k">if</span> <span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">cur_flayer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">cur_flayer</span><span class="o">.</span><span class="n">layer_name</span> <span class="o">=</span> <span class="n">producer</span><span class="o">.</span><span class="n">name</span>
|
||
|
||
<span class="k">return</span> <span class="n">tensor</span>
|
||
|
||
|
||
<span class="k">def</span> <span class="nf">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plugin_creator</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IPluginCreator</span><span class="p">,</span>
|
||
<span class="n">plugin_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">pfc</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plugin_info</span> <span class="o">=</span> <span class="n">PluginInfo</span><span class="p">(</span><span class="n">plugin_creator</span><span class="p">,</span> <span class="n">plugin_name</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">set_plugin_info</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">trt_network</span><span class="p">,</span> <span class="n">layer</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">plugin_info</span><span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="RotaryScalingType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RotaryScalingType">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">RotaryScalingType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">none</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">linear</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">dynamic</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="n">longrope</span> <span class="o">=</span> <span class="mi">3</span>
|
||
<span class="n">llama3</span> <span class="o">=</span> <span class="mi">4</span>
|
||
<span class="n">yarn</span> <span class="o">=</span> <span class="mi">5</span>
|
||
<span class="n">mrope</span> <span class="o">=</span> <span class="mi">6</span>
|
||
|
||
<div class="viewcode-block" id="RotaryScalingType.from_string">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RotaryScalingType.from_string">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">from_string</span><span class="p">(</span><span class="n">s</span><span class="p">):</span>
|
||
<span class="k">try</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">RotaryScalingType</span><span class="p">[</span><span class="n">s</span><span class="p">]</span>
|
||
<span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Unsupported rotary scaling type: </span><span class="si">{</span><span class="n">s</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">PositionEmbeddingType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">learned_absolute</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">rope_gptj</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">rope_gpt_neox</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="n">long_rope</span> <span class="o">=</span> <span class="mi">3</span>
|
||
<span class="n">alibi</span> <span class="o">=</span> <span class="mi">4</span>
|
||
<span class="n">alibi_with_scale</span> <span class="o">=</span> <span class="mi">5</span>
|
||
<span class="n">relative</span> <span class="o">=</span> <span class="mi">6</span>
|
||
<span class="n">chatglm</span> <span class="o">=</span> <span class="mi">7</span>
|
||
<span class="n">yarn</span> <span class="o">=</span> <span class="mi">8</span>
|
||
<span class="n">mrope</span> <span class="o">=</span> <span class="mi">9</span>
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType.is_rope">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_rope">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">is_rope</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">rope_gptj</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">long_rope</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mrope</span>
|
||
<span class="p">]</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType.is_mrope">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_mrope">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">is_mrope</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">mrope</span><span class="p">]</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType.is_alibi">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_alibi">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">is_alibi</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">alibi</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">alibi_with_scale</span><span class="p">]</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType.choices">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.choices">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">choices</span><span class="p">()</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span>
|
||
<span class="k">return</span> <span class="p">[</span><span class="n">embedding</span><span class="o">.</span><span class="n">name</span> <span class="k">for</span> <span class="n">embedding</span> <span class="ow">in</span> <span class="n">PositionEmbeddingType</span><span class="p">]</span></div>
|
||
|
||
|
||
<span class="k">def</span> <span class="fm">__str__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span>
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType.from_string">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.from_string">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">from_string</span><span class="p">(</span><span class="n">s</span><span class="p">):</span>
|
||
<span class="k">try</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">PositionEmbeddingType</span><span class="p">[</span><span class="n">s</span><span class="p">]</span>
|
||
<span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Unsupported position embedding type: </span><span class="si">{</span><span class="n">s</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="AttentionMaskType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AttentionMaskType">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">AttentionMaskType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">padding</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">causal</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">sliding_window_causal</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="n">bidirectional</span> <span class="o">=</span> <span class="mi">3</span>
|
||
<span class="n">bidirectionalglm</span> <span class="o">=</span> <span class="mi">4</span> <span class="c1"># TODO: merge this mask into bidirectional</span>
|
||
<span class="n">blocksparse</span> <span class="o">=</span> <span class="mi">5</span>
|
||
<span class="n">custom_mask</span> <span class="o">=</span> <span class="mi">6</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="LayerNormType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.LayerNormType">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">LayerNormType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">LayerNorm</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">RmsNorm</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">GroupNorm</span> <span class="o">=</span> <span class="mi">2</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="LayerNormPositionType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.LayerNormPositionType">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">LayerNormPositionType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">pre_layernorm</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">post_layernorm</span> <span class="o">=</span> <span class="mi">1</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="MLPType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.MLPType">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">MLPType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">MLP</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">GatedMLP</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">FusedGatedMLP</span> <span class="o">=</span> <span class="mi">2</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="activation">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.activation">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">activation</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">act_type</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an activation function.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> act_type : trt.ActivationType</span>
|
||
<span class="sd"> The type of the activation (RELU, TANH, SIGMOID, ...).</span>
|
||
|
||
<span class="sd"> The following closures are defined in functional.*:</span>
|
||
|
||
<span class="sd"> relu for op=trt.ActivationType.RELU</span>
|
||
<span class="sd"> tanh for op=trt.ActivationType.TANH</span>
|
||
<span class="sd"> sigmoid for op=trt.ActivationType.SIGMOID</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">act_type</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="int_clip">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.int_clip">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">int_clip</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">lower</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">upper</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">lower</span> <span class="o"><=</span> <span class="n">upper</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"Lower bound must be less than or equal to upper bound i.e. </span><span class="si">{</span><span class="n">lower</span><span class="si">}</span><span class="s2"> <= </span><span class="si">{</span><span class="n">upper</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="n">res</span> <span class="o">=</span> <span class="n">minimum</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">upper</span><span class="p">)</span>
|
||
<span class="n">res</span> <span class="o">=</span> <span class="n">maximum</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">lower</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">res</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="clip">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.clip">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">clip</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a CLIP operation that sets the range to [alpha, beta].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> alpha : float</span>
|
||
<span class="sd"> The lower bound of the CLIP function.</span>
|
||
|
||
<span class="sd"> beta : float</span>
|
||
<span class="sd"> The upper bound of the CLIP function.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">CLIP</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">beta</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="n">relu</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">RELU</span><span class="p">)</span>
|
||
<span class="n">tanh</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">TANH</span><span class="p">)</span>
|
||
<span class="n">sigmoid</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">SIGMOID</span><span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="silu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.silu">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">silu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a SiLU (`x * sigmoid(x)`) operation.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">input</span> <span class="o">*</span> <span class="n">sigmoid</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="swiglu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.swiglu">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">swiglu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a SwiGLU (`x * SiLU(gate)`) operation.</span>
|
||
|
||
<span class="sd"> That function takes a tensor, splits it into two halves along the last</span>
|
||
<span class="sd"> dimension, applies SiLU to the second half and multiply the results. The</span>
|
||
<span class="sd"> behavior is undefined if the last dimension is not even.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">x</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="n">chunk</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">silu</span><span class="p">(</span><span class="n">gate</span><span class="p">)</span> <span class="o">*</span> <span class="n">x</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="squared_relu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.squared_relu">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">squared_relu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a Squared ReLU operation.</span>
|
||
|
||
<span class="sd"> This function applies ReLU and squares the output.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">pow</span><span class="p">(</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="mf">2.0</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="cast">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.cast">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">cast</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a cast operation.</span>
|
||
|
||
<span class="sd"> For an input tensor of type INT8, this function sets the dynamic range of</span>
|
||
<span class="sd"> the input to [-127, 127] for automatic dequantization. For a cast into</span>
|
||
<span class="sd"> INT8, that function sets the dynamic range of the output to [-127, 127] for</span>
|
||
<span class="sd"> automatic quantization.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the cast is applied.</span>
|
||
|
||
<span class="sd"> dtype : str or trt.DataType</span>
|
||
<span class="sd"> The data type of the output tensor after the cast. When 'dtype' is</span>
|
||
<span class="sd"> provided as a string, it must be a name amongst the valid names.</span>
|
||
<span class="sd"> See _str_to_trt_dtype_dict in _utils.py for a list of supported</span>
|
||
<span class="sd"> types and type names.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the inserted layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
|
||
<span class="n">cvt_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">):</span>
|
||
<span class="n">cvt_dtype</span> <span class="o">=</span> <span class="n">dtype</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"</span><span class="si">%s</span><span class="s2"> is not supported"</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dtype</span><span class="p">))</span>
|
||
|
||
<span class="k">if</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">cvt_dtype</span><span class="p">:</span>
|
||
<span class="c1"># If input type and cast dtype are the same, do nothing</span>
|
||
<span class="k">return</span> <span class="nb">input</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_cast</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">cvt_dtype</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_output_type</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">cvt_dtype</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="s1">'int8'</span><span class="p">):</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">cvt_dtype</span> <span class="o">==</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="s1">'int8'</span><span class="p">):</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">output</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="flip">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.flip">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">flip</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Reverses the order of an n-D tensor along given axis in dims.</span>
|
||
|
||
<span class="sd"> That flip operation maps to a TensorRT ISliceLayer. For the dimensions</span>
|
||
<span class="sd"> listed in dims it copies the elements from the last one to the first one</span>
|
||
<span class="sd"> (from (N-1) down to 0 with a step of -1). For the dimensions not in 'dims',</span>
|
||
<span class="sd"> it copies the elements from the first one to the last one (from 0 to N-1</span>
|
||
<span class="sd"> with a step of 1).</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the cast is applied.</span>
|
||
|
||
<span class="sd"> dims : list or tuple</span>
|
||
<span class="sd"> The axes to flip. Negative indices are supported.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the inserted layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
|
||
<span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dims</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="o">-</span><span class="n">ndim</span> <span class="o"><=</span> <span class="n">value</span> <span class="o"><</span> <span class="n">ndim</span>
|
||
<span class="k">if</span> <span class="o">-</span><span class="n">ndim</span> <span class="o"><=</span> <span class="n">value</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">dims</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">+=</span> <span class="n">ndim</span>
|
||
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">dims</span><span class="p">))</span>
|
||
|
||
<span class="n">start_values</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dims</span> <span class="k">else</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
<span class="n">stride_values</span> <span class="o">=</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dims</span> <span class="k">else</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">start</span><span class="o">=</span><span class="n">start_values</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="o">=</span><span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(),</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="n">stride_values</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="interpolate">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.interpolate">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">interpolate</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">scale_factor</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'nearest'</span><span class="p">,</span>
|
||
<span class="n">align_corners</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">recompute_scale_factor</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">antialias</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document that function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
|
||
|
||
<span class="n">input_ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
|
||
<span class="k">assert</span> <span class="mi">2</span> <span class="o"><</span> <span class="n">input_ndim</span> <span class="o"><</span> <span class="mi">6</span><span class="p">,</span> <span class="s2">"Only 3D, 4D and 5D input Tensors supported"</span>
|
||
<span class="k">assert</span> <span class="p">(</span><span class="n">size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">^</span> <span class="p">(</span>
|
||
<span class="n">scale_factor</span>
|
||
<span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">),</span> <span class="s2">"Only one of out_shape or scales should be defined"</span>
|
||
|
||
<span class="k">assert</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">(</span><span class="s1">'nearest'</span><span class="p">,</span> <span class="s1">'linear'</span><span class="p">,</span> <span class="s1">'bilinear'</span><span class="p">,</span> <span class="s1">'bicubic'</span><span class="p">,</span> <span class="s1">'trilinear'</span><span class="p">,</span>
|
||
<span class="s1">'nearest-exact'</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">'trilinear'</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">5</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"trilinear only supports 5D tensor"</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">"bilinear"</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">4</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"bilinear only supports 4D tensor"</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">"linear"</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"linear only supports 3D tensor"</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_resize</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">input_shape</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
|
||
|
||
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">if</span> <span class="n">scale_factor</span><span class="p">:</span>
|
||
<span class="n">scale_len</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">,</span>
|
||
<span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="nb">int</span><span class="p">))</span> <span class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">scale_len</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">,</span> <span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="nb">int</span><span class="p">)):</span>
|
||
<span class="n">updated_scale</span> <span class="o">=</span> <span class="p">[</span><span class="n">scale_factor</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)]</span>
|
||
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">updated_scale</span> <span class="o">=</span> <span class="n">scale_factor</span>
|
||
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="nb">int</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">updated_scale</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span>
|
||
<span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span> <span class="k">if</span> <span class="n">i</span> <span class="o">></span> <span class="mi">1</span> <span class="k">else</span> <span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">size_len</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">size</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">size_len</span> <span class="o">==</span> <span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span>
|
||
<span class="k">if</span> <span class="n">size_len</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">updated_size</span> <span class="o">=</span> <span class="p">[</span><span class="n">size</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">updated_size</span> <span class="o">=</span> <span class="n">size</span>
|
||
|
||
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">if</span> <span class="n">i</span> <span class="o"><</span> <span class="mi">2</span> <span class="k">else</span> <span class="n">updated_size</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">2</span><span class="p">]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">updated_shape</span>
|
||
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'nearest'</span><span class="p">,</span> <span class="s1">'nearest-exact'</span><span class="p">]</span> <span class="ow">or</span> <span class="n">mode</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">InterpolationMode</span><span class="o">.</span><span class="n">NEAREST</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ASYMMETRIC</span>
|
||
<span class="k">elif</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'linear'</span><span class="p">,</span> <span class="s1">'bilinear'</span><span class="p">,</span> <span class="s1">'trilinear'</span><span class="p">]:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">InterpolationMode</span><span class="o">.</span><span class="n">LINEAR</span>
|
||
<span class="k">if</span> <span class="n">align_corners</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ALIGN_CORNERS</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
|
||
<span class="c1"># TODO, need to confirm the align_corners effect on bilinear mode.</span>
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">'bilinear'</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
|
||
|
||
<span class="k">elif</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'bicubic'</span><span class="p">]:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">InterpolationMode</span><span class="o">.</span><span class="n">CUBIC</span>
|
||
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
|
||
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">InterpolationMode</span><span class="o">.</span><span class="n">NEAREST</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ASYMMETRIC</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="matmul">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.matmul">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">matmul</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">mat2</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">transa</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">transb</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">use_fp32_acc</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a matrix multiplication.</span>
|
||
|
||
<span class="sd"> That operation maps to a tensorrt.IMatrixMultiplyLayer layer. As explained</span>
|
||
<span class="sd"> in the TensorRT documentation, it computes the inner product between the</span>
|
||
<span class="sd"> two inputs after applying an optional transposition on the inputs.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The first tensor (often called A).</span>
|
||
|
||
<span class="sd"> mat2 : Tensor</span>
|
||
<span class="sd"> The second tensor (often called B).</span>
|
||
|
||
<span class="sd"> transa : bool</span>
|
||
<span class="sd"> Is the first input transposed? Set to 'True' if you want the first</span>
|
||
<span class="sd"> input to be transposed, 'False' otherwise.</span>
|
||
|
||
<span class="sd"> transb : bool</span>
|
||
<span class="sd"> Is the second input transposed? Set to 'True' if you want the</span>
|
||
<span class="sd"> second input to be transposed, 'False' otherwise.</span>
|
||
|
||
<span class="sd"> use_fp32_acc: bool</span>
|
||
<span class="sd"> Set to 'True' if for accuracy reason, this fp16 matmul needs to use</span>
|
||
<span class="sd"> fp32 accumulation. This can be a per model and per matmul decision.</span>
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the inserted layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="c1"># This option is only supported for fp16, but not bf16 or any other precisions.</span>
|
||
<span class="n">use_fp32_acc</span> <span class="o">=</span> <span class="n">use_fp32_acc</span> <span class="ow">and</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span> <span class="ow">and</span> <span class="n">mat2</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span>
|
||
|
||
<span class="k">if</span> <span class="n">use_fp32_acc</span><span class="p">:</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="s1">'float32'</span><span class="p">)</span>
|
||
<span class="n">mat2</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">mat2</span><span class="p">,</span> <span class="s1">'float32'</span><span class="p">)</span>
|
||
|
||
<span class="nb">input</span><span class="p">,</span> <span class="n">mat2</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">mat2</span><span class="p">)</span>
|
||
<span class="n">op0</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">TRANSPOSE</span> <span class="k">if</span> <span class="n">transa</span> \
|
||
<span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">NONE</span>
|
||
<span class="n">op1</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">TRANSPOSE</span> <span class="k">if</span> <span class="n">transb</span> \
|
||
<span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">NONE</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_matrix_multiply</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op0</span><span class="p">,</span>
|
||
<span class="n">mat2</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op1</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">use_fp32_acc</span><span class="p">:</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="s2">"float16"</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">output</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="gemm_swiglu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gemm_swiglu">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">gemm_swiglu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">scale_d0</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">scale_d1</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">scale_output</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a matrix multiplication, followed by SwiGLU (`x * SiLU(gate)`) operation.</span>
|
||
|
||
<span class="sd"> The second SwiGLU operation takes the preceding tensor, splits it into two halves</span>
|
||
<span class="sd"> along the last dimension, applies SiLU to the second half and multiply the results. The</span>
|
||
<span class="sd"> behaviour is undefined if the last dimension is not even.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The first tensor (often called A).</span>
|
||
|
||
<span class="sd"> weight : Tensor</span>
|
||
<span class="sd"> The second tensor (often called B).</span>
|
||
|
||
<span class="sd"> bias : Optional[Tensor]</span>
|
||
<span class="sd"> The per-channel bias. The plugin with fp8 dtype does not support bias yet.</span>
|
||
|
||
<span class="sd"> scale_d0 : float</span>
|
||
<span class="sd"> The scale for dequantizing x, used for fp8</span>
|
||
|
||
<span class="sd"> scale_d1 : float</span>
|
||
<span class="sd"> The scale for dequantizing gate, used for fp8</span>
|
||
|
||
<span class="sd"> scale_output : float</span>
|
||
<span class="sd"> The scale for quantizing output, used for fp8</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the inserted layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'GemmSwiglu'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_swiglu_plugin</span>
|
||
<span class="k">if</span> <span class="n">p_dtype</span> <span class="o">==</span> <span class="s2">"fp8"</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">bias</span> <span class="o">==</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"fp8 gemm_swiglu does not support bias yet"</span>
|
||
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pf_has_bias</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"has_bias"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="mi">0</span> <span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">pf_scale_d0</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"scale_d0"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_d0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">pf_scale_d1</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"scale_d1"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_d1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">pf_scale_output</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"scale_output"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_output</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span>
|
||
<span class="p">[</span><span class="n">pf_type</span><span class="p">,</span> <span class="n">pf_has_bias</span><span class="p">,</span> <span class="n">pf_scale_d0</span><span class="p">,</span> <span class="n">pf_scale_d1</span><span class="p">,</span> <span class="n">pf_scale_output</span><span class="p">])</span>
|
||
<span class="n">gemm_swiglu_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"gemm_swiglu"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
|
||
<span class="c1"># TODO(anchengc) pass nullptr when no bias</span>
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">bias</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">weight</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">)))</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">gemm_swiglu_plug</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="constant">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.constant">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">constant</span><span class="p">(</span><span class="n">ndarray</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a constant layer.</span>
|
||
|
||
<span class="sd"> TensorRT graphs encapsulate constant values in the form of constant layers</span>
|
||
<span class="sd"> (tensorrt.IConstantLayer). That function creates such a layer from a Numpy</span>
|
||
<span class="sd"> array of values. After compilation of the network by TensorRT, those</span>
|
||
<span class="sd"> weights are stored in the serialized TensorRT engine.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> ndarray : numpy.ndarray</span>
|
||
<span class="sd"> The array of values (weights) encapsulated by this constant layer.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the inserted layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">weights</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">(</span><span class="n">np_dtype_to_trt</span><span class="p">(</span><span class="n">ndarray</span><span class="o">.</span><span class="n">dtype</span><span class="p">),</span> <span class="n">ndarray</span><span class="o">.</span><span class="n">ctypes</span><span class="o">.</span><span class="n">data</span><span class="p">,</span>
|
||
<span class="n">ndarray</span><span class="o">.</span><span class="n">size</span><span class="p">)</span>
|
||
<span class="c1"># Prevent underlying numpy array from going out of scope</span>
|
||
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">register_ndarray</span><span class="p">(</span><span class="n">ndarray</span><span class="p">)</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_constant</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">Dims</span><span class="p">(</span><span class="n">ndarray</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="n">weights</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_output_type</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">np_dtype_to_trt</span><span class="p">(</span><span class="n">ndarray</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
|
||
<span class="n">tensor</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="c1"># TODO: remove this WAR after https://nvbugs/4359151 fixed.</span>
|
||
<span class="n">set_np_weight</span><span class="p">(</span><span class="n">default_trtnet</span><span class="p">(),</span> <span class="n">layer</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">ndarray</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">tensor</span></div>
|
||
|
||
|
||
|
||
<span class="c1"># TODO: TensorRT uses sizes of the output dimensions.</span>
|
||
<span class="c1"># DL framework uses ends usually. Will change it to ends.</span>
|
||
<div class="viewcode-block" id="slice">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.slice">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">slice</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">starts</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">sizes</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">strides</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">mode</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">SampleMode</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">fill_value</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to extract a slice from a tensor.</span>
|
||
|
||
<span class="sd"> As described in the TensorRT documentation of the ISliceLayer, the slice</span>
|
||
<span class="sd"> layer has two variants: Static and dynamic.</span>
|
||
|
||
<span class="sd"> For static slicing, this function takes the starts and sizes values in the</span>
|
||
<span class="sd"> different dimensions to slice at layer creation time via a sequence of</span>
|
||
<span class="sd"> integers. For dynamic slicing, it accepts starts and sizes as</span>
|
||
<span class="sd"> tensorrt.ITensor`s.</span>
|
||
|
||
<span class="sd"> The slice layer selects for each dimension a start location from within the</span>
|
||
<span class="sd"> input tensor, and copies elements to the output tensor using a stride of 1</span>
|
||
<span class="sd"> across the input tensor. Start and size tensors must be 1-D int32 shape</span>
|
||
<span class="sd"> tensors if not specified as a sequence of integers.</span>
|
||
|
||
<span class="sd"> As an example, on input = [[0, 2, 4], [1, 3, 5]], the call to</span>
|
||
|
||
<span class="sd"> slice(input, start=[1, 0], size=[1, 2])</span>
|
||
|
||
<span class="sd"> will produce the tensor [[1, 3]] as output. The slice operator when</span>
|
||
<span class="sd"> executed by TensorRT will copy one row (because size[0] == 1) starting from</span>
|
||
<span class="sd"> the 2nd row (because start[0] == 1) and two columns (size[1] == 2) starting</span>
|
||
<span class="sd"> from the 1st column (because start[1] == 0).</span>
|
||
|
||
<span class="sd"> In pseudo-code the behavior of that operation can be described as follows</span>
|
||
<span class="sd"> for a 2D tensor (and easily be extended to more dimensions):</span>
|
||
|
||
<span class="sd"> output = Tensor(shape=sizes)</span>
|
||
<span class="sd"> for ii in range(sizes[0]):</span>
|
||
<span class="sd"> for jj in range(sizes[1]):</span>
|
||
<span class="sd"> output[ii][jj] = input[starts[0]+ii][starts[1]+jj]</span>
|
||
|
||
<span class="sd"> Note that it is common in deep-learning frameworks to use ranges</span>
|
||
<span class="sd"> [start:end] for similar operations. It can be emulated by setting the sizes</span>
|
||
<span class="sd"> argument such that in each dimension [start:start+size] == [start:end] i.e.</span>
|
||
<span class="sd"> size = end-start.</span>
|
||
|
||
<span class="sd"> TensorRT supports different slice modes but that function restricts that</span>
|
||
<span class="sd"> choice to `mode == tensorrt.SampleMode.STRICT_BOUNDS`.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the slicing is performed.</span>
|
||
|
||
<span class="sd"> starts : Union[Tensor, Sequence[int]]</span>
|
||
<span class="sd"> The starting points, in the input tensor, and each dimension.</span>
|
||
|
||
<span class="sd"> sizes : Union[Tensor, Sequence[int]]</span>
|
||
<span class="sd"> The number of elements in each dimension of the sliced tensor (output).</span>
|
||
|
||
<span class="sd"> strides : Union[Tensor, Sequence[int]]</span>
|
||
<span class="sd"> The step be taken from start, in input tensor.</span>
|
||
|
||
<span class="sd"> mode : trt.SampleMode</span>
|
||
<span class="sd"> The mode that controls how the slice operation handles out of bounds coordinates.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the slice layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">input_ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
|
||
<span class="n">trt_starts</span> <span class="o">=</span> <span class="n">starts</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">starts</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="n">trt_starts</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)]</span> <span class="c1"># unused dummy value</span>
|
||
|
||
<span class="n">trt_sizes</span> <span class="o">=</span> <span class="n">sizes</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="n">trt_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)]</span> <span class="c1"># unused dummy value</span>
|
||
|
||
<span class="n">trt_strides</span> <span class="o">=</span> <span class="n">strides</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">strides</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">or</span> <span class="n">strides</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">trt_strides</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)]</span>
|
||
|
||
<span class="k">if</span> <span class="n">fill_value</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fill_value</span><span class="p">,</span> <span class="nb">float</span><span class="p">):</span>
|
||
<span class="n">fill_value</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">(</span><span class="n">fill_value</span><span class="p">))</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">start</span><span class="o">=</span><span class="n">trt_starts</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="o">=</span><span class="n">trt_sizes</span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="n">trt_strides</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">starts</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">starts</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">sizes</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">strides</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">strides</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="ow">is</span> <span class="n">trt</span><span class="o">.</span><span class="n">SampleMode</span><span class="o">.</span><span class="n">FILL</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fill_value</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="n">fill_value</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="rand">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.rand">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">rand</span><span class="p">(</span><span class="n">shape</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">low</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">high</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'float32'</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> This operation adds a fill layer that generates a random (uniform) tensor with the specified shape and data type.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> shape: Tensor</span>
|
||
<span class="sd"> The shape of the tensor needed to be generated.</span>
|
||
<span class="sd"> low: float</span>
|
||
<span class="sd"> The minimum value (inclusive) of the range used for random.</span>
|
||
<span class="sd"> high: float</span>
|
||
<span class="sd"> The maximum value (inclusive) of the range used for random.</span>
|
||
<span class="sd"> dtype: Union[str, trt.DataType]</span>
|
||
<span class="sd"> The desired data type for the output tensor.</span>
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The generated random tensor produced by the fill layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="c1"># NOTE: DISABLED FOR NOW UNTIL THE FILL LAYER (RANDOM_UNIFORM) in TRT IS FIXED</span>
|
||
<span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">"The rand() op is temporarily disabled."</span>
|
||
<span class="n">low</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">(</span><span class="n">low</span><span class="p">))</span>
|
||
<span class="n">high</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">(</span><span class="n">high</span><span class="p">))</span>
|
||
<span class="n">trt_dtype</span> <span class="o">=</span> <span class="n">dtype</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span> <span class="k">else</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_fill</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">trt</span><span class="o">.</span><span class="n">FillOperation</span><span class="o">.</span><span class="n">RANDOM_UNIFORM</span><span class="p">,</span>
|
||
<span class="n">trt_dtype</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">shape</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">low</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">high</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="categorical_sample">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.categorical_sample">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">categorical_sample</span><span class="p">(</span><span class="n">probs</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">rand_data</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> This is a sampling operation and an equivalent of torch.distributions.Categorical.sample()</span>
|
||
<span class="sd"> i.e. given a probability distribution tensor, it samples an index of that tensor.</span>
|
||
<span class="sd"> See: https://pytorch.org/docs/stable/distributions.html#torch.distributions.categorical.Categorical.sample</span>
|
||
<span class="sd"> NOTE: This assumes that the given probabilities are **not** normalized.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> probs: Tensor</span>
|
||
<span class="sd"> A 1-D floating point tensor representing the probability distributions.</span>
|
||
<span class="sd"> rand_data: Tensor (optional)</span>
|
||
<span class="sd"> A random tensor of same shape as `probs` tensor.</span>
|
||
<span class="sd"> If not provided, this function will add a rand() op to generate it and use for sampling.</span>
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor containing a single index of the `probs` tensor representing the sample.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">probs</span> <span class="o">=</span> <span class="n">probs</span> <span class="o">/</span> <span class="nb">sum</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="n">rand_shape</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">assert</span> <span class="n">probs</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">></span> <span class="mi">0</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">probs</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
|
||
<span class="n">rand_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
|
||
<span class="n">rand_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">rand_shape</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">rand_data</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">rand_data</span> <span class="o">=</span> <span class="n">rand</span><span class="p">(</span><span class="n">rand_shape</span><span class="p">,</span> <span class="n">low</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">probs</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">rand_shape</span> <span class="o">==</span> <span class="n">shape</span><span class="p">(</span><span class="n">rand_data</span><span class="p">)</span>
|
||
<span class="n">rand_data</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">rand_data</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">shape</span><span class="p">(</span><span class="n">probs</span><span class="p">))</span>
|
||
<span class="n">cum_probs</span> <span class="o">=</span> <span class="n">cumsum</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">cmp</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">cum_probs</span> <span class="o">>=</span> <span class="n">rand_data</span><span class="p">,</span> <span class="n">probs</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">samples</span> <span class="o">=</span> <span class="n">argmax</span><span class="p">(</span><span class="n">cmp</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">samples</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="Conditional">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Conditional">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">Conditional</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to conditionally execute two code paths/subgraphs.</span>
|
||
|
||
<span class="sd"> Usage:</span>
|
||
<span class="sd"> 1. conditional = Conditional(condition)</span>
|
||
<span class="sd"> 2. input_1_ = conditional.add_input(input_1)</span>
|
||
<span class="sd"> ...</span>
|
||
<span class="sd"> input_n_ = conditional.add_input(input_n)</span>
|
||
<span class="sd"> 3. Construct the graph to get true_output_value and false_output_value using input_1_, ..., input_n_</span>
|
||
<span class="sd"> 4. output = conditional.add_output(true_output_value, false_output_value)</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">condition</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_if_conditional</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">condition</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">condition</span> <span class="o">=</span> <span class="n">view</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="p">[])</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">layer</span><span class="o">.</span><span class="n">set_condition</span><span class="p">(</span><span class="n">condition</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="Conditional.add_input">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Conditional.add_input">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">add_input</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="n">in_node</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer</span><span class="o">.</span><span class="n">add_input</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">in_node</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">in_node</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Conditional.add_output">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Conditional.add_output">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">add_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">true_value</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">false_value</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="n">out_node</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer</span><span class="o">.</span><span class="n">add_output</span><span class="p">(</span><span class="n">true_value</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">false_value</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">out_node</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">out_node</span><span class="p">)</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<span class="c1"># TODO: support step.</span>
|
||
<div class="viewcode-block" id="arange">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.arange">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">arange</span><span class="p">(</span><span class="n">start</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span> <span class="n">end</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to fill a 1D tensor.</span>
|
||
|
||
<span class="sd"> The tensor is filled with the values between start and end with a step of 1</span>
|
||
<span class="sd"> between the different elements. In pseudo-code, it corresponds to a tensor</span>
|
||
<span class="sd"> populated with the values:</span>
|
||
|
||
<span class="sd"> output = Tensor([dtype(ii) for ii in range(start, end, 1)])</span>
|
||
|
||
<span class="sd"> For example, a call to arange(3, 6, 'int32') will add an operation to the</span>
|
||
<span class="sd"> TensorRT graph that will produce [3, 4, 5] when executed. The call to</span>
|
||
<span class="sd"> arange(2, 5, 'float32') will add a layer to generate [2.0, 3.0, 4.0].</span>
|
||
|
||
<span class="sd"> This operation is implemented using a tensorrt.IFillLayer in</span>
|
||
<span class="sd"> trt.FillOperation.LINSPACE mode.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> start : Union[Tensor, int]</span>
|
||
<span class="sd"> The starting point of the range.</span>
|
||
|
||
<span class="sd"> end : Union[Tensor, int]</span>
|
||
<span class="sd"> The end point of the range.</span>
|
||
|
||
<span class="sd"> dtype : str</span>
|
||
<span class="sd"> The type of the elements. See _str_to_trt_dtype_dict in _utils.py</span>
|
||
<span class="sd"> for a list of supported types and type names.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the fill layer. It is a 1D tensor containing</span>
|
||
<span class="sd"> `end-start` elements of type `dtype`.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">res_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span>
|
||
<span class="n">array_func</span> <span class="o">=</span> <span class="n">int32_array</span> <span class="k">if</span> <span class="n">res_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span> <span class="k">else</span> <span class="n">int64_array</span>
|
||
<span class="n">start</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">array_func</span><span class="p">(</span><span class="n">start</span><span class="p">))</span>
|
||
<span class="n">end</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">array_func</span><span class="p">(</span><span class="n">end</span><span class="p">))</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span> <span class="ow">or</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span>
|
||
<span class="k">assert</span> <span class="n">end</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span> <span class="ow">or</span> <span class="n">end</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span>
|
||
<span class="k">if</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">!=</span> <span class="n">end</span><span class="o">.</span><span class="n">dtype</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span> <span class="c1"># end == trt.int64</span>
|
||
<span class="k">if</span> <span class="n">res_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span>
|
||
<span class="n">end</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="s2">"int32"</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">start</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="s2">"int64"</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span> <span class="c1"># start == trt.int64 and end == trt.int32</span>
|
||
<span class="k">if</span> <span class="n">res_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span>
|
||
<span class="n">start</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="s2">"int32"</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">end</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="s2">"int64"</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"</span><span class="si">%s</span><span class="s2"> is not supported"</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">start</span><span class="p">))</span>
|
||
|
||
<span class="k">assert</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">end</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"start type (</span><span class="si">{</span><span class="n">start</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) != end type (</span><span class="si">{</span><span class="n">end</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">)"</span>
|
||
<span class="n">step</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">start</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">to_array</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
|
||
<span class="n">num</span> <span class="o">=</span> <span class="n">end</span> <span class="o">-</span> <span class="n">start</span>
|
||
<span class="n">num</span> <span class="o">=</span> <span class="n">num</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_fill</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">trt</span><span class="o">.</span><span class="n">FillOperation</span><span class="o">.</span><span class="n">LINSPACE</span><span class="p">,</span>
|
||
<span class="n">start</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 1</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">start</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 0</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">step</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 1</span>
|
||
<span class="n">tensor</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span> <span class="o">!=</span> <span class="n">res_dtype</span><span class="p">:</span>
|
||
<span class="n">tensor</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">tensor</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="expand">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">expand</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">expand_shape</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to expand a tensor.</span>
|
||
|
||
<span class="sd"> The operation expands the input tensor in the singleton dimensions to the</span>
|
||
<span class="sd"> size indicated by the corresponding dimension in the `expand_shape` tensor.</span>
|
||
<span class="sd"> In other words, given an input tensor with dimensions of size 1, those</span>
|
||
<span class="sd"> dimensions will be expanded to the size in `expand_shape`.</span>
|
||
|
||
<span class="sd"> For example, a tensor of shape [4, 3, 1, 3] will be expanded to a tensor of</span>
|
||
<span class="sd"> shape [4, 3, 2, 3] by the layer created using expand(input, [4, 3, 2, 3]).</span>
|
||
|
||
<span class="sd"> The expansion may either replicate the values or be mapped to a view with a</span>
|
||
<span class="sd"> stride of 0 in the expanded dimensions. For example, for a tensor [[3, 2]] of</span>
|
||
<span class="sd"> shape [1, 2],</span>
|
||
|
||
<span class="sd"> expand([[3, 2]], [2, 2])</span>
|
||
|
||
<span class="sd"> can be used to expand the input to [[3, 2], [3, 2]].</span>
|
||
|
||
<span class="sd"> This operation is implemented using a tensorrt.ISliceLayer. The current</span>
|
||
<span class="sd"> implementation does not verify that non-singleton dimensions are not</span>
|
||
<span class="sd"> shrunk. In other words, for an input of shape [4, 1, 2],</span>
|
||
|
||
<span class="sd"> expand(input, [3, 2, 2])</span>
|
||
|
||
<span class="sd"> will produce a tensor of shape [3, 2, 2]. That behavior is subject to</span>
|
||
<span class="sd"> change in the future.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> expand_shape : Tensor</span>
|
||
<span class="sd"> The new shape of the expanded tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the expand layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">start</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)],</span>
|
||
<span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)],</span> <span class="c1"># unused dummy value</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span> <span class="c1"># unused dummy value</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c1"># The stride is either:</span>
|
||
<span class="c1"># 0 for dimensions of size 1 (i.e. shape(input, i) - 1 == 1 - 1 == 0) or,</span>
|
||
<span class="c1"># 1 for dimensions of size > 1 since minimum(value >= 1, 1) == 1.</span>
|
||
<span class="n">stride_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span>
|
||
<span class="p">[</span><span class="n">minimum</span><span class="p">((</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)])</span>
|
||
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">expand_shape</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride_tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="einsum">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.einsum">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">einsum</span><span class="p">(</span><span class="n">einsum_eq</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">inputs</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an Einsum operation.</span>
|
||
|
||
<span class="sd"> That operation maps to tensorrt.IEinsumLayer. As explained in the TensorRT</span>
|
||
<span class="sd"> documentation, this layer implements a summation over the elements of the</span>
|
||
<span class="sd"> inputs along dimensions specified by the equation parameter, based on the</span>
|
||
<span class="sd"> Einstein summation convention. The layer can have one or more inputs of</span>
|
||
<span class="sd"> rank >= 0. All the inputs must be of the same data type. This layer supports</span>
|
||
<span class="sd"> all TensorRT data types except bool. There is one output tensor of the same</span>
|
||
<span class="sd"> type as the input tensors. The shape of output tensor is determined by the</span>
|
||
<span class="sd"> equation.</span>
|
||
|
||
<span class="sd"> The equation specifies ASCII lower-case letters for each dimension in the</span>
|
||
<span class="sd"> inputs in the same order as the dimensions, separated by comma for each</span>
|
||
<span class="sd"> input. The dimensions labeled with the same subscript must match or be</span>
|
||
<span class="sd"> able to be broadcast. Repeated subscript labels in one input take the diagonal.</span>
|
||
<span class="sd"> Repeating a label across multiple inputs means that those axes will be</span>
|
||
<span class="sd"> multiplied. Omitting a label from the output means values along those axes</span>
|
||
<span class="sd"> will be summed. In implicit mode, the indices which appear once in the</span>
|
||
<span class="sd"> expression will be part of the output in increasing alphabetical order. In</span>
|
||
<span class="sd"> explicit mode, the output can be controlled by specifying output subscript</span>
|
||
<span class="sd"> labels by adding an arrow (‘->’) followed by subscripts for the output. For</span>
|
||
<span class="sd"> example, “ij,jk->ik” is equivalent to “ij,jk”. Ellipsis (‘…’) can be used</span>
|
||
<span class="sd"> in place of subscripts to broadcast the dimensions. See the TensorRT</span>
|
||
<span class="sd"> Developer Guide for more details on equation syntax.</span>
|
||
|
||
<span class="sd"> Many common operations can be expressed using the Einsum equation. For</span>
|
||
<span class="sd"> example:</span>
|
||
<span class="sd"> Matrix Transpose: ij->ji</span>
|
||
<span class="sd"> Sum: ij-></span>
|
||
<span class="sd"> Matrix-Matrix Multiplication: ik,kj->ij</span>
|
||
<span class="sd"> Dot Product: i,i-></span>
|
||
<span class="sd"> Matrix-Vector Multiplication: ik,k->i</span>
|
||
<span class="sd"> Batch Matrix Multiplication: ijk,ikl->ijl</span>
|
||
<span class="sd"> Batch Diagonal: …ii->…i</span>
|
||
|
||
<span class="sd"> Note that TensorRT does not support ellipsis or diagonal operations, so</span>
|
||
<span class="sd"> neither does TensorRT-LLM.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> einsum_eq : str</span>
|
||
<span class="sd"> The Einsum equation.</span>
|
||
|
||
<span class="sd"> inputs: Sequence[Tensor]</span>
|
||
<span class="sd"> The sequence of inputs consumed by the Einsum operation.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the Einsum operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_einsum</span><span class="p">([</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">],</span>
|
||
<span class="n">einsum_eq</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="permute">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.permute">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">permute</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to permute the dimensions of a tensor.</span>
|
||
|
||
<span class="sd"> The dimensions of the input tensor are permuted according to the sequence</span>
|
||
<span class="sd"> of dimensions in 'dims'. That operation maps to tensorrt.IShuffleLayer where</span>
|
||
<span class="sd"> the second transposition is described by the indices in 'dims'.</span>
|
||
|
||
<span class="sd"> Given a tensor of rank N, the result of the permutation is a tensor of rank</span>
|
||
<span class="sd"> N in which the i-th output dimension is the dims[i]-th input dimension.</span>
|
||
|
||
<span class="sd"> For example, permute(input, [1, 0]) will transpose a 2D tensor by permuting</span>
|
||
<span class="sd"> the rows and columns.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to permute.</span>
|
||
|
||
<span class="sd"> dims : Sequence[int]</span>
|
||
<span class="sd"> The description of the permutation.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the permutation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">dims</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">dims</span><span class="p">),</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">second_transpose</span> <span class="o">=</span> <span class="n">dims</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="transpose">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.transpose">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">transpose</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim0</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim1</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to transpose two dimensions of a tensor.</span>
|
||
|
||
<span class="sd"> That operation produces a tensor in which the dimensions 'dim0' and 'dim1'</span>
|
||
<span class="sd"> are permuted. The other dimensions, if the rank of the tensor is greater</span>
|
||
<span class="sd"> than 2, remain untouched.</span>
|
||
|
||
<span class="sd"> That function is a helper built on the 'functional.permute' function.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to transpose.</span>
|
||
|
||
<span class="sd"> dim0 : int</span>
|
||
<span class="sd"> The first dimension to transpose.</span>
|
||
|
||
<span class="sd"> dim1 : int</span>
|
||
<span class="sd"> The second dimension to transpose.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the permutation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">permutation</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()))</span>
|
||
<span class="n">permutation</span><span class="p">[</span><span class="n">dim0</span><span class="p">]</span> <span class="o">=</span> <span class="n">dim1</span>
|
||
<span class="n">permutation</span><span class="p">[</span><span class="n">dim1</span><span class="p">]</span> <span class="o">=</span> <span class="n">dim0</span>
|
||
|
||
<span class="k">return</span> <span class="n">permute</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">permutation</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="view">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.view">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">view</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">zero_is_placeholder</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to create a view of a tensor.</span>
|
||
|
||
<span class="sd"> That operation adds a tensorrt.IShuffleLayer to the network. If the 'shape'</span>
|
||
<span class="sd"> parameter is a Tensor, that view is dynamic. Otherwise, it is a static</span>
|
||
<span class="sd"> view.</span>
|
||
|
||
<span class="sd"> Note that TensorRT limits the number of inferred dimensions to 1. It means</span>
|
||
<span class="sd"> that the shape sequence or tensor cannot contain more than one -1. This</span>
|
||
<span class="sd"> function asserts if a static shape sequence violates that constraint (the runtime values of a shape Tensor are not checked).</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to view.</span>
|
||
|
||
<span class="sd"> shape : Union[Tensor, Sequence[int]]</span>
|
||
<span class="sd"> The shape of the new tensor.</span>
|
||
|
||
<span class="sd"> zero_is_placeholder : bool</span>
|
||
<span class="sd"> When that parameter is True, the 0s in 'shape' are replaced by the</span>
|
||
<span class="sd"> sizes of the corresponding dimensions from the 'input'. Otherwise,</span>
|
||
<span class="sd"> the 0s denote zero-length dimensions.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the view/shuffle layer.</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="c1"># TensorRT demands that at most one dimension is permitted to be specified as -1</span>
|
||
<span class="k">def</span> <span class="nf">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="nb">list</span><span class="p">):</span>
|
||
<span class="n">inferred_dim_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">list</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">inferred_dim_list</span><span class="p">)</span> <span class="o"><=</span> <span class="mi">1</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">zero_is_placeholder</span> <span class="o">=</span> <span class="n">zero_is_placeholder</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="n">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="n">shape</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">shape</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
|
||
<span class="n">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">reshape_dims</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"</span><span class="si">%s</span><span class="s2"> is not supported"</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">shape</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="flatten">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.flatten">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">flatten</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">start_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">end_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Flattens input by reshaping it into a one-dimensional tensor.</span>
|
||
|
||
<span class="sd"> If start_dim or end_dim are passed, only dimensions starting with start_dim and</span>
|
||
<span class="sd"> ending with end_dim are flattened. The order of elements in input is unchanged.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to flatten.</span>
|
||
|
||
<span class="sd"> start_dim : int</span>
|
||
<span class="sd"> The first dim to flatten.</span>
|
||
|
||
<span class="sd"> end_dim : int</span>
|
||
<span class="sd"> The last dim to flatten.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the flatten layer.</span>
|
||
|
||
<span class="sd"> '''</span>
|
||
<span class="n">shape</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">shape</span>
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">start_dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span> <span class="n">start_dim</span> <span class="o">+=</span> <span class="n">ndim</span>
|
||
<span class="k">if</span> <span class="n">end_dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span> <span class="n">end_dim</span> <span class="o">+=</span> <span class="n">ndim</span>
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">()</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">start_dim</span><span class="p">):</span>
|
||
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
|
||
<span class="k">if</span> <span class="n">end_dim</span> <span class="o">-</span> <span class="n">start_dim</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">flat_dim</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">start_dim</span><span class="p">,</span> <span class="n">end_dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
|
||
<span class="n">flat_dim</span> <span class="o">*=</span> <span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">flat_dim</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">end_dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">ndim</span><span class="p">):</span>
|
||
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
|
||
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">new_shape</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="expand_dims">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_dims">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">shape_cast_dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to expand the tensor shape with singleton dimensions.</span>
|
||
|
||
<span class="sd"> That function adds a tensorrt.IShuffleLayer to the network. Given an 'input'</span>
|
||
<span class="sd"> of rank N and a sequence of M dimensions, the output tensor produced by</span>
|
||
<span class="sd"> this operation (when executed by TensorRT) will have a rank of N+M. Singleton</span>
|
||
<span class="sd"> dimensions will be inserted at the different positions in 'dim'.</span>
|
||
|
||
<span class="sd"> The pseudo-code for that operation is:</span>
|
||
|
||
<span class="sd"> new_shape, ii = [], 0</span>
|
||
<span class="sd"> for jj in range(input.rank() + len(dim)):</span>
|
||
<span class="sd"> new_shape.append(1 if jj in dims else input.shape[ii++])</span>
|
||
|
||
<span class="sd"> For example, for a tensor of shape [3, 4, 1, 5]</span>
|
||
|
||
<span class="sd"> expand_dims(input, [0, 2])</span>
|
||
|
||
<span class="sd"> will produce a tensor of shape [1, 3, 1, 4, 1, 5].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to expand.</span>
|
||
|
||
<span class="sd"> dim : Union[int, Sequence[int]]</span>
|
||
<span class="sd"> The positions in the output tensor where to insert singleton</span>
|
||
<span class="sd"> dimensions.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the shuffle layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="p">)</span>
|
||
|
||
<span class="n">out_ndim</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="o">+</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
|
||
<span class="n">input_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">cast_to_dtype</span><span class="o">=</span><span class="n">shape_cast_dtype</span><span class="p">)</span>
|
||
<span class="n">out_shapes</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="n">j</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">out_ndim</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dim</span><span class="p">:</span>
|
||
<span class="n">out_shapes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">out_shapes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">gather</span><span class="p">(</span><span class="n">input_shape</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">j</span><span class="p">))</span>
|
||
<span class="n">j</span> <span class="o">=</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">1</span>
|
||
|
||
<span class="n">out_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">out_shapes</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">out_shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="c1"># NOTE: Jointly added with Apple</span>
|
||
<div class="viewcode-block" id="squeeze">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.squeeze">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">squeeze</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">zero_is_placeholder</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to remove singleton dimensions of a tensor.</span>
|
||
|
||
<span class="sd"> This functions creates an operation that removes singleton dimension</span>
|
||
<span class="sd"> (dimension of size 1) at positions 'dim' in the input tensor. It works with</span>
|
||
<span class="sd"> negative values for the 'dim'.</span>
|
||
|
||
<span class="sd"> For example, for a tensor 'input' of shape [1, 4, 1, 4]:</span>
|
||
|
||
<span class="sd"> squeeze(input, 0) will produce an output of shape [4, 1, 4],</span>
|
||
<span class="sd"> squeeze(input, 2) will produce an output of shape [1, 4, 4],</span>
|
||
<span class="sd"> squeeze(input, [0, 2]) will produce an output of shape [4, 4],</span>
|
||
<span class="sd"> squeeze(input, [-2]) will produce an output of shape [1, 4, 4],</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor for which the singleton dimensions will be removed.</span>
|
||
|
||
<span class="sd"> dim : Union[int, Sequence[int]]</span>
|
||
<span class="sd"> The index of the singleton dimensions in the input tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()))</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="p">)</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">s</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dim</span><span class="p">:</span>
|
||
<span class="k">continue</span>
|
||
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
|
||
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span> <span class="k">else</span> <span class="p">[]</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="n">zero_is_placeholder</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="nb">input</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="unsqueeze">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.unsqueeze">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">unsqueeze</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">axis</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to insert a singleton dimension to a tensor.</span>
|
||
|
||
<span class="sd"> That functions creates an operation that insert a singleton dimension</span>
|
||
<span class="sd"> (dimension of size 1) at position 'axis' in the output tensor. It works with</span>
|
||
<span class="sd"> negative values for the 'axis'.</span>
|
||
|
||
<span class="sd"> For example, for a tensor 'input' of shape [4, 4]:</span>
|
||
|
||
<span class="sd"> unsqueeze(input, 0) will produce an output of shape [1, 4, 4],</span>
|
||
<span class="sd"> unsqueeze(input, 1) will produce an output of shape [4, 1, 4],</span>
|
||
<span class="sd"> unsqueeze(input, -1) will produce an output of shape [4, 4, 1],</span>
|
||
<span class="sd"> unsqueeze(input, -2) will produce an output of shape [4, 1, 4],</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to expand with a singleton dimension.</span>
|
||
|
||
<span class="sd"> axis : int</span>
|
||
<span class="sd"> The index of the singleton dimension in the output tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">axis</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">axis</span> <span class="o">=</span> <span class="n">axis</span> <span class="o">+</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
|
||
|
||
<span class="k">return</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">axis</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="stack">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.stack">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">stack</span><span class="p">(</span><span class="n">inputs</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to contact input tensors along a new dimension.</span>
|
||
|
||
<span class="sd"> The function creates an operation that creates a new dim for all the</span>
|
||
<span class="sd"> input tensors and then concatenates them along that new dim.</span>
|
||
<span class="sd">.</span>
|
||
|
||
<span class="sd"> All the tensors in 'inputs' must have the same shape.</span>
|
||
|
||
<span class="sd"> for ii in range(inputs[0].rank()):</span>
|
||
<span class="sd"> assert all(inp.shape[ii] == inputs[0].shape[ii] for inp in inputs)</span>
|
||
|
||
<span class="sd"> The shape of the output tensor is defined as:</span>
|
||
|
||
<span class="sd"> output.rank() = inputs[0].rank() + 1</span>
|
||
|
||
<span class="sd"> output.shape[dim] = len(inputs)</span>
|
||
|
||
<span class="sd"> for ii in range(inputs[0].rank()):</span>
|
||
<span class="sd"> if ii < dim:</span>
|
||
<span class="sd"> output.shape[ii] = inputs[0].shape[ii]</span>
|
||
<span class="sd"> else:</span>
|
||
<span class="sd"> output.shape[ii+1] = inputs[0].shape[ii]</span>
|
||
|
||
<span class="sd"> For example, given a sequence of two 2D tensors [[0, 1], [2, 3]] and</span>
|
||
<span class="sd"> [[4, 5], [6, 7]] both of shape [2, 2],</span>
|
||
|
||
<span class="sd"> stack(inputs, 0)</span>
|
||
|
||
<span class="sd"> will produce [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] of shape [2, 2, 2] and</span>
|
||
|
||
<span class="sd"> stack(inputs, 1)</span>
|
||
|
||
<span class="sd"> will produce [[[0, 1], [4, 5]], [[2, 3], [6, 7]]] of shape [2, 2, 2].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> inputs : Sequence[Tensor]</span>
|
||
<span class="sd"> The sequence of tensors to stack.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension in which the stack is performed.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor that contains the input tensors stacked along a new dimension.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">concat</span><span class="p">([</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">inp</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="n">dim</span><span class="p">)</span> <span class="k">for</span> <span class="n">inp</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="expand_dims_like">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_dims_like">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">right</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to expand the first tensor to the same rank as the second</span>
|
||
<span class="sd"> tensor.</span>
|
||
|
||
<span class="sd"> That function takes a first tensor. It also accepts an integer or a float,</span>
|
||
<span class="sd"> in which case it creates a constant tensor from it. In both cases, the rank</span>
|
||
<span class="sd"> of that first tensor is compared to the rank of the second tensor. If they</span>
|
||
<span class="sd"> are of the same rank, the first tensor is returned. Otherwise, the first</span>
|
||
<span class="sd"> tensor is expanded on the left to match the rank of the second tensor.</span>
|
||
|
||
<span class="sd"> Note that the shapes do not have to match, only the rank is considered in</span>
|
||
<span class="sd"> that function.</span>
|
||
|
||
<span class="sd"> For example, for a pair of tensors of shapes [3, 4] and [4, 3, 2], the</span>
|
||
<span class="sd"> first tensor will be expanded to a tensor of rank 3 and shape [1, 3, 4].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> left : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The first tensor to expand. When a scalar value is provided as a</span>
|
||
<span class="sd"> parameter, that function first creates a tensor before expanding it</span>
|
||
<span class="sd"> (if needed).</span>
|
||
|
||
<span class="sd"> right : Tensor</span>
|
||
<span class="sd"> The reference tensor to match.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the shuffle layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="nb">float</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">and</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp16_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
|
||
<span class="n">left_ndim</span> <span class="o">=</span> <span class="n">left</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="n">right_ndim</span> <span class="o">=</span> <span class="n">right</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">right_ndim</span> <span class="o">></span> <span class="n">left_ndim</span><span class="p">:</span>
|
||
<span class="n">new_ndim</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">right_ndim</span> <span class="o">-</span> <span class="n">left_ndim</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">new_ndim</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">left</span></div>
|
||
|
||
|
||
|
||
<span class="c1"># If dim is None, return a 1-D TensorRT-LLM tensor of the size</span>
|
||
<span class="c1"># If dim is not None, return a 0-D TensorRT-LLM tensor of the dimension size</span>
|
||
<div class="viewcode-block" id="shape">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.shape">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">cast_to_dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">clip_before_cast</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to create a shape tensor.</span>
|
||
|
||
<span class="sd"> The shape tensor can either be the shape of the input tensor when the</span>
|
||
<span class="sd"> parameter dim is None or a scalar (tensor of rank 0) that corresponds to</span>
|
||
<span class="sd"> the size of dim-th dimension.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor from which we want to extract the shape or the</span>
|
||
<span class="sd"> size in one dimension.</span>
|
||
|
||
<span class="sd"> dim : Optional[int]</span>
|
||
<span class="sd"> The dimension from which to extract the size. If it is None, the</span>
|
||
<span class="sd"> entire shape of the input tensor is returned.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor that contains the shape of the input tensor (if 'dim' is None)</span>
|
||
<span class="sd"> or the size in the dimension 'dim' of the input tensor. If 'dim' is</span>
|
||
<span class="sd"> 'None', that tensor has the same rank as the input tensor, otherwise</span>
|
||
<span class="sd"> its rank is 0.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shape</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">res</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">cast_to_dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">clip_before_cast</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="p">(</span><span class="n">cast_to_dtype</span> <span class="o">==</span> <span class="s1">'int32'</span>
|
||
<span class="ow">or</span> <span class="n">cast_to_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span>
|
||
<span class="n">clip_before_cast</span>
|
||
<span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"This parameter only expects a tuple of 2 integers (lower, upper) but got </span><span class="si">{</span><span class="n">clip_before_cast</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="n">res</span> <span class="o">=</span> <span class="n">int_clip</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">clip_before_cast</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">clip_before_cast</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
|
||
<span class="n">res</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">cast_to_dtype</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">res</span>
|
||
|
||
<span class="k">return</span> <span class="n">gather</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">indices</span><span class="o">=</span><span class="n">dim</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">([])</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="gather">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gather">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">gather</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">indices</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to gather elements from a tensor.</span>
|
||
|
||
<span class="sd"> That function implements the GatherElements operator from the ONNX</span>
|
||
<span class="sd"> specification as described in</span>
|
||
|
||
<span class="sd"> https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements</span>
|
||
|
||
<span class="sd"> The input and indices arguments must have the same rank >= 1. The operation</span>
|
||
<span class="sd"> will produce a tensor with the same shape as the indices tensor. The axis</span>
|
||
<span class="sd"> is the dimension to gather on.</span>
|
||
|
||
<span class="sd"> As shown in the ONNX description, for a 3D tensor, the output is:</span>
|
||
|
||
<span class="sd"> out[i][j][k] = input[indices[i][j][k]][j][k] if axis = 0,</span>
|
||
<span class="sd"> out[i][j][k] = input[i][indices[i][j][k]][k] if axis = 1,</span>
|
||
<span class="sd"> out[i][j][k] = input[i][j][indices[i][j][k]] if axis = 2.</span>
|
||
|
||
<span class="sd"> For example,</span>
|
||
|
||
<span class="sd"> gather([[4, 2], [5, 3]], 0, [[1, 0], [0, 1]])</span>
|
||
|
||
<span class="sd"> will produce [[5, 2], [4, 3]].</span>
|
||
|
||
<span class="sd"> gather([[1, 2, 3], [4, 5, 6], 1, [[1], [0]])</span>
|
||
|
||
<span class="sd"> will produce [[2], [4]]. See the ONNX documentation for more examples.</span>
|
||
|
||
<span class="sd"> That operation maps to the TensorRT IGatherLayer.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to gather elements from.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension to gather on.</span>
|
||
|
||
<span class="sd"> indices : Union[Tensor, int]</span>
|
||
<span class="sd"> The positions in the 'dim' dimension to gather from.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor containing the gathered elements. It has the same shape as</span>
|
||
<span class="sd"> the indices tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">indices</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">indices</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">indices</span><span class="p">]))</span>
|
||
|
||
<span class="c1"># The input and indices tensors must have the same rank.</span>
|
||
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="n">indices</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather_v2</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">indices</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">GatherMode</span><span class="o">.</span><span class="n">ELEMENT</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="n">dim</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">dim</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="select">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.select">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">select</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to select a slice of elements from a tensor.</span>
|
||
|
||
<span class="sd"> Given an input tensor, that function creates an operation that selects the</span>
|
||
<span class="sd"> index-th slice of elements in the dimension 'dim' to create a new tensor.</span>
|
||
<span class="sd"> The output tensor has a shape in which the input dimension 'dim' is</span>
|
||
<span class="sd"> removed.</span>
|
||
|
||
<span class="sd"> The 'index' can either be an integer or a 1D tensor containing a single</span>
|
||
<span class="sd"> element.</span>
|
||
|
||
<span class="sd"> For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
|
||
<span class="sd"> [3, 3],</span>
|
||
|
||
<span class="sd"> select(input, 0, 1)</span>
|
||
|
||
<span class="sd"> will create a tensor of shape [3] that contains the [2, 1, 2].</span>
|
||
|
||
<span class="sd"> Regarding the shape of the output tensor, the dimension 'dim' is removed.</span>
|
||
<span class="sd"> It means that for a tensor of shape [4, 2, 6, 3],</span>
|
||
|
||
<span class="sd"> select(input, 2, 4)</span>
|
||
|
||
<span class="sd"> will select the 5th slice (index == 4) from the 3rd dimension (dim == 2)</span>
|
||
<span class="sd"> and return a tensor of shape [4, 2, 3] (i.e. the 3rd dimension is removed).</span>
|
||
|
||
<span class="sd"> That operation maps to the TensorRT IGatherLayer.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to select from.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension to select from.</span>
|
||
|
||
<span class="sd"> index : Union[Tensor, int]</span>
|
||
<span class="sd"> The index of the slice in the 'dim' dimension to select.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor containing the selected slice.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">index</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">index</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">index</span><span class="p">]))</span>
|
||
<span class="k">assert</span> <span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">index</span><span class="o">.</span><span class="n">size</span><span class="p">(</span>
|
||
<span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"index should have rank 1, got </span><span class="si">{</span><span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span>
|
||
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()):</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">:</span>
|
||
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">index</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">new_shape</span><span class="p">))</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="index_select">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.index_select">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">index_select</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to select slices of elements from a tensor.</span>
|
||
|
||
<span class="sd"> Given an input tensor, that function creates an operation that selects the</span>
|
||
<span class="sd"> slices of elements in the dimension 'dim' at the indices listed in 'index'</span>
|
||
<span class="sd"> to create a new tensor. The output tensor has the same rank as the input</span>
|
||
<span class="sd"> tensor.</span>
|
||
|
||
<span class="sd"> The 'index' is a tensor of rank 1.</span>
|
||
|
||
<span class="sd"> For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
|
||
<span class="sd"> [3, 3],</span>
|
||
|
||
<span class="sd"> index_select(input, 0, [0, 1])</span>
|
||
|
||
<span class="sd"> will create a tensor of shape [2, 3] that contains the [[4, 2, 5], [2, 1, 2]].</span>
|
||
|
||
<span class="sd"> Regarding the shape of the output tensor, the dimension 'dim' has the same</span>
|
||
<span class="sd"> size as the 'index' tensor. It means that for a input tensor of shape [4, 2, 6, 3],</span>
|
||
|
||
<span class="sd"> index_select(input, 2, [1, 4])</span>
|
||
|
||
<span class="sd"> will select the 2nd and 5th slices (index == 1 or 4) from the 3rd dimension</span>
|
||
<span class="sd"> (dim == 2) and return a tensor of shape [4, 2, 2, 3] (i.e. the 3rd</span>
|
||
<span class="sd"> dimension is shrunk to 2).</span>
|
||
|
||
<span class="sd"> Note that this operation can also be used to expand a tensor in the 'dim'</span>
|
||
<span class="sd"> dimension, for example, on input [[0, 1], [2, 3]],</span>
|
||
|
||
<span class="sd"> index_select(input, 1, [0, 0, 0])</span>
|
||
|
||
<span class="sd"> will produce a tensor of shape [2, 3] containing [[0, 0, 0], [2, 2, 2]].</span>
|
||
|
||
<span class="sd"> That operation maps to the TensorRT IGatherLayer.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to select from.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension to select from.</span>
|
||
|
||
<span class="sd"> index : Tensor</span>
|
||
<span class="sd"> The indices of the slices in the 'dim' dimension to select.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor containing the selected slices.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"index should have rank 1, got </span><span class="si">{</span><span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span>
|
||
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()):</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">:</span>
|
||
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="n">index</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">index</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">new_shape</span><span class="p">))</span></div>
|
||
|
||
|
||
|
||
<span class="c1"># NOTE: Jointly added with Apple</span>
|
||
<div class="viewcode-block" id="scatter">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.scatter">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">scatter</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">indices</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">updates</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> This operation adds a layer that creates an output tensor by element-wise</span>
|
||
<span class="sd"> copying values from the input tensor and then updating values by the given</span>
|
||
<span class="sd"> `indices` and `updates` tensors.</span>
|
||
<span class="sd"> For a 2D input tensor, it first copies the input to output,</span>
|
||
<span class="sd"> then updates the output tensor like the following for each entry in `updates`:</span>
|
||
<span class="sd"> output[indices[i][j]][j] = updates[i][j] if dim=0</span>
|
||
<span class="sd"> output[i][indices[i][j]] = updates[i][j] if dim=1</span>
|
||
<span class="sd"> If the `input` tensor is [[1, 2, 3], [4, 5, 6]],</span>
|
||
<span class="sd"> the indices tensor is [[1, 2], [0, 1]],</span>
|
||
<span class="sd"> the updates tensor is [[-1, -2], [-3, -4]], and dim=1</span>
|
||
<span class="sd"> the output tensor will be [[1, -1, -2], [-3, -4, 6]].</span>
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input: Tensor</span>
|
||
<span class="sd"> The input data that needs to be updated.</span>
|
||
<span class="sd"> dim: int</span>
|
||
<span class="sd"> The axis on which the scatter is to be performed.</span>
|
||
<span class="sd"> indices: Tensor</span>
|
||
<span class="sd"> An integer tensor of the same rank as input that indicates the positions to be updated.</span>
|
||
<span class="sd"> updates: Tensor</span>
|
||
<span class="sd"> A data tensor of same shape as the `indices` tensor that contains the update values.</span>
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor created by the element-wise scatter layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_scatter</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">indices</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">updates</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ScatterMode</span><span class="o">.</span><span class="n">ELEMENT</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">dim</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="gather_nd">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gather_nd">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">gather_nd</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">indices</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">batch_dims</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Adds a layer that performs a gather with some element-wise dimensions.</span>
|
||
<span class="sd"> See: https://onnx.ai/onnx/operators/onnx__GatherND.html</span>
|
||
<span class="sd"> The gather is performed on dim=batch_dims.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input: Tensor</span>
|
||
<span class="sd"> The tensor on which the gather operation is performed.</span>
|
||
<span class="sd"> indices: Tensor</span>
|
||
<span class="sd"> The tensor that indicates which entries to be gathered.</span>
|
||
<span class="sd"> batch_dims: int</span>
|
||
<span class="sd"> The number of first dimensions that should be skipped before gather starts.</span>
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor created by the gather layer with GatherMode.ND.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">gather_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather_v2</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">indices</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">GatherMode</span><span class="o">.</span><span class="n">ND</span><span class="p">)</span>
|
||
<span class="n">gather_layer</span><span class="o">.</span><span class="n">num_elementwise_dims</span> <span class="o">=</span> <span class="n">batch_dims</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">gather_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">gather_layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="nonzero">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.nonzero">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">nonzero</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Adds a layer that finds the indices of non-zero values of the input tensor.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input: Tensor</span>
|
||
<span class="sd"> The input tensor for which we need to find the indices of non-zero values.</span>
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor of shape [D, C] where D is the number of dimensions of `input` and</span>
|
||
<span class="sd"> C is the number of non-zero values in it.</span>
|
||
<span class="sd"> Each column of this 2D tensor represents the index tuple for each non-zero value.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">non_zero_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_non_zero</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">non_zero_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">non_zero_layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="masked_select">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.masked_select">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">masked_select</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to select elements from a tensor according to a boolean</span>
|
||
<span class="sd"> mask tensor.</span>
|
||
|
||
<span class="sd"> Given an input tensor, that function creates an operation that selects</span>
|
||
<span class="sd"> elements at the indices indicated by the boolean mask tensor to create</span>
|
||
<span class="sd"> a new tensor. The output tensor is a 1-D tensor.</span>
|
||
|
||
<span class="sd"> The input tensor must have rank >= 1. The shapes of the input tensor and</span>
|
||
<span class="sd"> the mask tensor don’t need to match, but they must be able to be broadcasted.</span>
|
||
|
||
<span class="sd"> For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
|
||
<span class="sd"> [3, 3],</span>
|
||
|
||
<span class="sd"> masked_select(input, [[True, False, True], [False, True, False], [True, False, True]])</span>
|
||
|
||
<span class="sd"> will create a tensor of shape [5] that contains the [4, 5, 1, 4, 1].</span>
|
||
|
||
<span class="sd"> masked_select(input, [[True], [False], [True]])</span>
|
||
|
||
<span class="sd"> will create a tensor of shape [6] that contains the [4, 2, 5, 4, 7, 1].</span>
|
||
|
||
<span class="sd"> masked_select(input, [[False, False, True]])</span>
|
||
|
||
<span class="sd"> will create a tensor of shape [3] that contains the [5, 2, 1].</span>
|
||
|
||
<span class="sd"> masked_select(input, [False])</span>
|
||
|
||
<span class="sd"> will create a tensor of shape [0] which is empty.</span>
|
||
|
||
<span class="sd"> That operation is implemented by NonZero, Shuffle and GatherV2 layers</span>
|
||
<span class="sd"> in TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to select from.</span>
|
||
|
||
<span class="sd"> mask : Tensor</span>
|
||
<span class="sd"> The boolean mask tensor that indicates elements to select.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The 1-D tensor containing the selected elements.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">>=</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"input should have rank >= 1"</span>
|
||
<span class="nb">input</span><span class="p">,</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span>
|
||
<span class="n">expanded_mask</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">))</span>
|
||
|
||
<span class="n">non_zero_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_non_zero</span><span class="p">(</span><span class="n">expanded_mask</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">shuffle_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="n">non_zero_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
|
||
<span class="n">shuffle_layer</span><span class="o">.</span><span class="n">second_transpose</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">gather_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather_v2</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">shuffle_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">GatherMode</span><span class="o">.</span><span class="n">ND</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">gather_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">gather_layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="cumsum">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.cumsum">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">cumsum</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">prefer_plugin</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to calculate inclusive cumulative sum of elements of</span>
|
||
<span class="sd"> a tensor in a given dimension.</span>
|
||
|
||
<span class="sd"> Given an input tensor, that function creates an operation that calculates</span>
|
||
<span class="sd"> inclusive cumulative sum of elements in the dimension 'dim' to create</span>
|
||
<span class="sd"> a new tensor. The output tensor has the same shape as the input tensor.</span>
|
||
|
||
<span class="sd"> The input tensor must have rank >= 1. The 'dim' must be valid, and negative</span>
|
||
<span class="sd"> value is supported.</span>
|
||
|
||
<span class="sd"> For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
|
||
<span class="sd"> [3, 3],</span>
|
||
|
||
<span class="sd"> cumsum(input, 0)</span>
|
||
|
||
<span class="sd"> will produce [[4, 2, 5], [6, 3, 7], [10, 10, 8]].</span>
|
||
|
||
<span class="sd"> cumsum(input, 1)</span>
|
||
|
||
<span class="sd"> will produce [[4, 6, 11], [2, 3, 5], [4, 11, 12]].</span>
|
||
|
||
<span class="sd"> That operation is implemented by TensorRT ILoopLayer.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to calculate the inclusive cumulative sum.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension to calculate the inclusive cumulative sum. Negative</span>
|
||
<span class="sd"> value is supported.</span>
|
||
|
||
<span class="sd"> prefer_plugin : bool</span>
|
||
<span class="sd"> Whether to use the cumsumLastDim plugin if dim is last dim.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor containing the inclusive cumulative sum of input.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">>=</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"input should have rank >= 1"</span>
|
||
<span class="k">assert</span> <span class="n">dim</span> <span class="o"><</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="ow">and</span> <span class="n">dim</span> <span class="o">>=</span> <span class="o">-</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">(</span>
|
||
<span class="p">),</span> <span class="sa">f</span><span class="s2">"dim should be in [</span><span class="si">{</span><span class="o">-</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">, </span><span class="si">{</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">) when input have rank </span><span class="si">{</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span>
|
||
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())[</span><span class="mi">0</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="o">==</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">prefer_plugin</span><span class="p">:</span>
|
||
<span class="n">last_dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">last_dim</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span> <span class="c1"># dynamic?</span>
|
||
<span class="n">last_dim</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">old_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">input_2d</span> <span class="o">=</span> <span class="n">unsqueeze</span><span class="p">(</span>
|
||
<span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="c1"># special handling of rank-1 dynamic tensor</span>
|
||
<span class="k">elif</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span>
|
||
<span class="n">input_2d</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">last_dim</span><span class="p">]),</span>
|
||
<span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">input_2d</span> <span class="o">=</span> <span class="nb">input</span>
|
||
<span class="n">cumsum_last_dim_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">(</span>
|
||
<span class="p">)</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span><span class="s1">'CumsumLastDim'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">cumsum_last_dim_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">input_length</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"input_length"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">input_2d</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"type_id"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">input_2d</span><span class="o">.</span><span class="n">dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">input_length</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">cumsum_last_dim_plug</span> <span class="o">=</span> <span class="n">cumsum_last_dim_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span>
|
||
<span class="s2">"cumsum_last_dim"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">input_2d</span><span class="p">]</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span>
|
||
<span class="n">cumsum_last_dim_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">cumsum_last_dim_plg_creator</span><span class="p">,</span>
|
||
<span class="s2">"cumsum_last_dim"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">old_shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">output</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="c1"># credit to Apple</span>
|
||
<span class="n">reduction_length</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">reduction_range</span> <span class="o">=</span> <span class="n">arange</span><span class="p">(</span><span class="n">constant_to_tensor_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="s1">'int64'</span><span class="p">,</span>
|
||
<span class="n">to_array</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span>
|
||
<span class="n">reduction_length</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="s1">'int64'</span><span class="p">)</span>
|
||
<span class="n">lower_triangle</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">reduction_range</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="o"><=</span> <span class="n">unsqueeze</span><span class="p">(</span><span class="n">reduction_range</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">unsqueeze</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="n">lower_triangle</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">output</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">slice_shape</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()):</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">:</span>
|
||
<span class="n">slice_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
|
||
|
||
<span class="n">zero_tensor</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">slice_shape</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">zero_tensor</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">zero_tensor</span><span class="p">,</span>
|
||
<span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">slice_shape</span><span class="p">))])</span>
|
||
<span class="n">slice_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">slice_shape</span><span class="p">)</span>
|
||
<span class="n">zero_tensor</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">zero_tensor</span><span class="p">,</span> <span class="n">slice_shape</span><span class="p">)</span>
|
||
|
||
<span class="n">loop_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_loop</span><span class="p">()</span>
|
||
<span class="n">trip_limit</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span>
|
||
<span class="n">loop_layer</span><span class="o">.</span><span class="n">add_trip_limit</span><span class="p">(</span><span class="n">trip_limit</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">TripLimit</span><span class="o">.</span><span class="n">COUNT</span><span class="p">)</span>
|
||
|
||
<span class="n">iterator_layer</span> <span class="o">=</span> <span class="n">loop_layer</span><span class="o">.</span><span class="n">add_iterator</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
|
||
<span class="n">cur_slice</span> <span class="o">=</span> <span class="n">iterator_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">running_sum_layer</span> <span class="o">=</span> <span class="n">loop_layer</span><span class="o">.</span><span class="n">add_recurrence</span><span class="p">(</span><span class="n">zero_tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">running_sum</span> <span class="o">=</span> <span class="n">running_sum_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">cur_sum_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_elementwise</span><span class="p">(</span>
|
||
<span class="n">cur_slice</span><span class="p">,</span> <span class="n">running_sum</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">SUM</span><span class="p">)</span>
|
||
<span class="n">cur_sum</span> <span class="o">=</span> <span class="n">cur_sum_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">running_sum_layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">cur_sum</span><span class="p">)</span>
|
||
|
||
<span class="n">loop_output_layer</span> <span class="o">=</span> <span class="n">loop_layer</span><span class="o">.</span><span class="n">add_loop_output</span><span class="p">(</span>
|
||
<span class="n">cur_sum</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">LoopOutput</span><span class="o">.</span><span class="n">CONCATENATE</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
|
||
<span class="n">loop_output_layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">trip_limit</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">loop_output_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">loop_output_layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="masked_scatter">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.masked_scatter">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">masked_scatter</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">source</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add the masked_scatter base on PyTorch definition.</span>
|
||
|
||
<span class="sd"> See https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter_.html#torch.Tensor.masked_scatter_ for a</span>
|
||
<span class="sd"> description of that function.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> mask : Tensor</span>
|
||
<span class="sd"> The boolean mask tensor that indicates elements to select.</span>
|
||
|
||
<span class="sd"> source: Tensor</span>
|
||
<span class="sd"> The tensor to copy from</span>
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor containing the source tensor selected by mask.</span>
|
||
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">>=</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"input should have rank >= 1"</span>
|
||
<span class="nb">input</span><span class="p">,</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span>
|
||
<span class="n">expanded_mask</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">))</span>
|
||
|
||
<span class="n">non_zero_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_non_zero</span><span class="p">(</span><span class="n">expanded_mask</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">shuffle_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="n">non_zero_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
|
||
<span class="n">shuffle_layer</span><span class="o">.</span><span class="n">second_transpose</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">source</span> <span class="o">=</span> <span class="n">source</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="n">scatter_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_scatter</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">shuffle_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">source</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ScatterMode</span><span class="o">.</span><span class="n">ND</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">scatter_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">scatter_layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="concat">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.concat">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">concat</span><span class="p">(</span><span class="n">inputs</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">]],</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to concatenate tensors.</span>
|
||
|
||
<span class="sd"> The function creates an operation that concatenates the tensors from the</span>
|
||
<span class="sd"> sequence 'inputs'. The concatenation is done along the dimension 'dim'.</span>
|
||
|
||
<span class="sd"> All the tensors in 'inputs' must have the same shape expect for the</span>
|
||
<span class="sd"> dimension 'dim'.</span>
|
||
|
||
<span class="sd"> for ii in range(inputs[0].rank()):</span>
|
||
<span class="sd"> assert (ii == dim) or all(inp.shape[ii] == inputs[0].shape[ii] for inp in inputs)</span>
|
||
|
||
<span class="sd"> The shape of the output tensor is defined as:</span>
|
||
|
||
<span class="sd"> for ii in range(inputs[0].rank()):</span>
|
||
<span class="sd"> # Same size as all the inputs in dimension ii != dim.</span>
|
||
<span class="sd"> output.shape[ii] = inputs[0].shape[ii]</span>
|
||
|
||
<span class="sd"> # Sum of the sizes in the different inputs in dimension 'dim'.</span>
|
||
<span class="sd"> if ii == dim:</span>
|
||
<span class="sd"> for jj in range(1, len(inputs)):</span>
|
||
<span class="sd"> output.shape[ii] += inputs[jj].shape[ii]</span>
|
||
|
||
<span class="sd"> For example, given a sequence of two 2D tensors [[0, 1], [2, 3]] and</span>
|
||
<span class="sd"> [[4, 5], [6, 7]] both of shape [2, 2],</span>
|
||
|
||
<span class="sd"> concat(inputs, 0)</span>
|
||
|
||
<span class="sd"> will produce [[0, 1], [2, 3], [4, 5], [6, 7]] of shape [4, 2] and</span>
|
||
|
||
<span class="sd"> concat(inputs, 1)</span>
|
||
|
||
<span class="sd"> will produce [[0, 1, 4, 5], [2, 3, 6, 7]] of shape [2, 4].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> inputs : Sequence[Union[Tensor, int]]</span>
|
||
<span class="sd"> The sequence of tensors to concatenate. For integers, that function</span>
|
||
<span class="sd"> creates constant tensors.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension in which the concatenation is performed.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor that contains the concatenation of the tensors.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span>
|
||
<span class="n">inputs</span>
|
||
<span class="p">)</span> <span class="o">></span> <span class="mi">0</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"Number of inputs (</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span><span class="si">}</span><span class="s2">) to the concatenation layer must be > 0."</span>
|
||
<span class="n">tmp</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="n">inputs</span> <span class="o">=</span> <span class="n">constants_to_tensors_</span><span class="p">(</span><span class="o">*</span><span class="n">inputs</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">i</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">tmp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">i</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="mi">1</span><span class="p">]))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">tmp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_concatenation</span><span class="p">([</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">tmp</span><span class="p">])</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">tmp</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">ndim</span><span class="p">())[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="softmax">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.softmax">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">softmax</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to compute softmax on a tensor.</span>
|
||
|
||
<span class="sd"> That operation computes the softmax on the input tensor in the dimension</span>
|
||
<span class="sd"> 'dim' if specified. Otherwise, it is applied on the last dimension.</span>
|
||
|
||
<span class="sd"> It inserts a ISoftmaxLayer to the TensorRT graph.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which to apply softmax.</span>
|
||
|
||
<span class="sd"> dim : Optional[int]</span>
|
||
<span class="sd"> The dimension used to apply softmax.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor of the softmax layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="n">dim</span>
|
||
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_softmax</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="k">def</span> <span class="nf">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">per_token_scale</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to perform lookup in a tensor.</span>
|
||
|
||
<span class="sd"> That operation performs the lookup needed by embedding layers. Given a</span>
|
||
<span class="sd"> 'weight' tensor of shape [rows, cols], it produces a tensor of shape</span>
|
||
<span class="sd"> [inputs.size(0), cols] where the ith row corresponds to the input[i] row in</span>
|
||
<span class="sd"> the weight tensor.</span>
|
||
|
||
<span class="sd"> It inserts a IPluginV2Layer.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor contains the indices to perform the lookup.</span>
|
||
|
||
<span class="sd"> weight : Tensor</span>
|
||
<span class="sd"> The table to gather from.</span>
|
||
|
||
<span class="sd"> rank : int</span>
|
||
<span class="sd"> The mpi rank.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor of the lookup layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Lookup'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">per_token_scale</span><span class="o">.</span><span class="n">dtype</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">rank</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">pf_type</span><span class="p">,</span> <span class="n">rank</span><span class="p">])</span>
|
||
<span class="n">lookup_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"lookup"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">per_token_scale</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">per_token_scale</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">lookup_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">"lookup"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="embedding">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.embedding">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">embedding</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">tp_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">tp_group</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">sharding_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">tp_rank</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">per_token_scale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to perform embedding lookup.</span>
|
||
|
||
<span class="sd"> That operation performs the embedding lookup. The 'input' tensor contains</span>
|
||
<span class="sd"> the identifiers of the rows of 'weight' to gather.</span>
|
||
|
||
<span class="sd"> 1. Distribute the embedding lookup table over multiple GPUs</span>
|
||
<span class="sd"> When 'tp_size' is greater than 1 and the 'tp_group' is defined, this</span>
|
||
<span class="sd"> embedding lookup is distributed among multiple GPUs.</span>
|
||
|
||
<span class="sd"> When 'sharding_dim==0', each GPU stores a subset of the rows of the embedding</span>
|
||
<span class="sd"> table (the number of rows per GPU is given by weight.shape[0] and the offset to</span>
|
||
<span class="sd"> the 1st row stored on the GPU is given by tp_rank * weight.shape[0]). Each</span>
|
||
<span class="sd"> parallel rank will query all the indices and set 0s for the weights that</span>
|
||
<span class="sd"> are not stored on the associated GPU. To compute the final result, a</span>
|
||
<span class="sd"> parallel all-reduce operation is added to the TensorRT graph. That lookup</span>
|
||
<span class="sd"> can be performed using either the lookup plugin (used when 'per_token_scale' is provided) or native TensorRT operators.</span>
|
||
|
||
<span class="sd"> When 'sharding_dim==1', each GPU stores a subset of the embedding table's columns.</span>
|
||
<span class="sd"> Each rank can obtain a portion of the embedding results.</span>
|
||
<span class="sd"> Then the embedding is collected using the all-gather operation.</span>
|
||
<span class="sd"> Related transposition operations are also used to obtain the final results.</span>
|
||
|
||
<span class="sd"> 2. Store embedding lookup table as a whole</span>
|
||
<span class="sd"> When 'tp_size' is not greater than 1, the embedding lookup table will not</span>
|
||
<span class="sd"> be divided. In this case, when 'per_token_scale' is provided,</span>
|
||
<span class="sd"> the operation is implemented using a plugin (without the all-reduce operation).</span>
|
||
<span class="sd"> Otherwise, this operation is implemented using the standard IGatherLayer in TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor that contains the indices to perform the lookup.</span>
|
||
|
||
<span class="sd"> weight : Tensor</span>
|
||
<span class="sd"> The table to gather from.</span>
|
||
|
||
<span class="sd"> tp_size : int</span>
|
||
<span class="sd"> The number of GPUs collaborating to perform that embedding.</span>
|
||
|
||
<span class="sd"> tp_group : Optional[List[int]]</span>
|
||
<span class="sd"> The group of world ranks participating in the all-reduce (sharding_dim == 0)</span>
|
||
<span class="sd"> or all-gather (sharding_dim == 1) when tp_size > 1.</span>
|
||
|
||
<span class="sd"> sharding_dim : int</span>
|
||
<span class="sd"> sharding_dim = 0 means that we shard the embedding table in vocab dim;</span>
|
||
<span class="sd"> sharding_dim = 1 means that we shard the embedding table in embedding dim.</span>
|
||
|
||
<span class="sd"> tp_rank : int</span>
|
||
<span class="sd"> The tensor parallelism rank. Used to calculate offset in TP on vocab dim.</span>
|
||
|
||
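<span class="sd"> per_token_scale : Optional[Tensor]</span>
|
||
<span class="sd"> Optional per-token scale passed to the lookup plugin. When provided, the lookup plugin is used instead of the gather-based implementation.</span>
|
||
|
||
<span class="sd"> padding : Optional[Tensor]</span>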
|
||
<span class="sd"> Additional padding added to the end of the embedding table before feeding into gather op.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the embedding lookup layer.</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="c1"># Per token scale is only supported by lookup plugin so if per_token_scale is not None, we must use lookup plugin</span>
|
||
<span class="c1"># Otherwise, we prefer to use ootb</span>
|
||
<span class="n">use_lookup_plugin</span> <span class="o">=</span> <span class="n">per_token_scale</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="k">if</span> <span class="n">padding</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">padded_weight</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">weight</span><span class="p">,</span> <span class="n">padding</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">padded_weight</span> <span class="o">=</span> <span class="n">weight</span>
|
||
|
||
<span class="c1"># Distribute embedding lookup table across multiple GPU</span>
|
||
<span class="k">if</span> <span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">tp_group</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">sharding_dim</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># TP on vocab_size dimension</span>
|
||
<span class="k">if</span> <span class="n">tp_rank</span> <span class="o">==</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"Rank cannot be none for tensor parallelism on vocab dim"</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">use_lookup_plugin</span><span class="p">:</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">tp_rank</span><span class="p">,</span> <span class="n">per_token_scale</span><span class="p">)</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">allreduce</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">shape_weight</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">weight</span><span class="p">)</span>
|
||
<span class="n">vocab_size</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">shape_weight</span><span class="p">,</span> <span class="n">starts</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">sizes</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
|
||
<span class="n">tmp_input</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">-</span> <span class="n">vocab_size</span> <span class="o">*</span> <span class="n">tp_rank</span>
|
||
|
||
<span class="c1"># Identify the valid indices</span>
|
||
<span class="n">is_qualified</span> <span class="o">=</span> <span class="n">op_and</span><span class="p">(</span><span class="n">tmp_input</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">tmp_input</span> <span class="o"><</span> <span class="n">vocab_size</span><span class="p">)</span>
|
||
<span class="n">is_qualified_expand</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">is_qualified</span><span class="p">,</span>
|
||
<span class="p">[</span><span class="n">is_qualified</span><span class="o">.</span><span class="n">ndim</span><span class="p">()])</span>
|
||
|
||
<span class="c1"># Replace the invalid ones to zero</span>
|
||
<span class="n">placeholder_input</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">is_qualified</span><span class="p">,</span> <span class="n">tmp_input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Get the temporal results</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span>
|
||
<span class="n">padded_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">placeholder_input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">tmp_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Set zero for invalid results</span>
|
||
<span class="n">placeholder_tmp</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">is_qualified_expand</span><span class="p">,</span> <span class="n">tmp_output</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">placeholder</span> <span class="o">=</span> <span class="n">placeholder_tmp</span> <span class="o">-</span> <span class="n">placeholder_tmp</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">is_qualified_expand</span><span class="p">,</span> <span class="n">tmp_output</span><span class="p">,</span> <span class="n">placeholder</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Use all reduce to collect the results</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">allreduce</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">)</span>
|
||
|
||
<span class="k">elif</span> <span class="n">sharding_dim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span> <span class="c1"># TP on hidden dimension</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="n">padded_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="c1"># [dim0, local_dim] -> [dim0 * tp_size, local_dim] --> [dim0, local_dim * tp_size]</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">allgather</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">,</span> <span class="n">gather_dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s1">'Tensor Parallelism only support splitting Embedding lookup along hidden (sharding_dim==1) and vocab (sharding_dim==0) dimensionis'</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c1"># Store embedding lookup table as a whole</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">use_lookup_plugin</span><span class="p">:</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span>
|
||
<span class="n">padded_weight</span><span class="p">,</span>
|
||
<span class="n">rank</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">per_token_scale</span><span class="o">=</span><span class="n">per_token_scale</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="n">padded_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">x</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="constant_to_tensor_">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.constant_to_tensor_">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">constant_to_tensor_</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="nb">bool</span><span class="p">],</span>
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">,</span> <span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">to_array</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="c1"># deduce the type from the given value</span>
|
||
<span class="c1"># NOTE: bool is a subtype of int, so bool needs to be checked first</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="nb">bool</span><span class="p">):</span>
|
||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">bool</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">float32</span>
|
||
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
|
||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">array_fn_dict</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">:</span> <span class="n">int64_array</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span> <span class="n">int32_array</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span> <span class="n">fp32_array</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span> <span class="n">fp16_array</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">:</span> <span class="n">bf16_array</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">bool</span><span class="p">:</span> <span class="n">bool_array</span><span class="p">,</span>
|
||
<span class="p">}</span>
|
||
<span class="k">assert</span> <span class="n">dtype</span> <span class="ow">in</span> <span class="n">array_fn_dict</span>
|
||
<span class="k">return</span> <span class="n">constant</span><span class="p">(</span><span class="n">array_fn_dict</span><span class="p">[</span><span class="n">dtype</span><span class="p">]([</span><span class="nb">input</span><span class="p">]</span> <span class="k">if</span> <span class="n">to_array</span> <span class="k">else</span> <span class="nb">input</span><span class="p">))</span>
|
||
|
||
<span class="k">return</span> <span class="nb">input</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="constants_to_tensors_">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.constants_to_tensors_">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">constants_to_tensors_</span><span class="p">(</span>
|
||
<span class="o">*</span><span class="n">inputs</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="o">...</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Helper function to create tensors from multiple inputs.</span>
|
||
|
||
<span class="sd"> For each input, this function first creates a constant tensor if the input</span>
|
||
<span class="sd"> is an integer or a float. Then, if any input is int64, it upcasts other</span>
|
||
<span class="sd"> integer inputs to int64.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> inputs : Tuple[Union[Tensor, int, float], ...]</span>
|
||
<span class="sd"> The inputs to create tensors from.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tuple of tensors.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">has_int64</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span><span class="n">i</span> <span class="o">>=</span> <span class="mi">2</span><span class="o">**</span><span class="mi">31</span> <span class="ow">or</span> <span class="n">i</span> <span class="o"><</span> <span class="o">-</span><span class="mi">2</span><span class="o">**</span><span class="mi">31</span><span class="p">)</span>\
|
||
<span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">and</span> <span class="n">i</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">:</span>
|
||
<span class="n">has_int64</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="k">break</span>
|
||
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">has_int64</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">)</span>
|
||
|
||
<span class="n">result</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">and</span> <span class="n">i</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span>
|
||
<span class="n">result</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span> <span class="k">if</span> <span class="n">has_int64</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">result</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">i</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">result</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="broadcast_helper">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.broadcast_helper">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">broadcast_helper</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
|
||
<span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Helper function to perform a broadcast.</span>
|
||
|
||
<span class="sd"> For each input, that function first creates a constant tensor if the input</span>
|
||
<span class="sd"> is an integer or a float. Then, if needed, it expands the smaller tensor to</span>
|
||
<span class="sd"> make sure its rank is the same as the larger one.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> left : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The first input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> right : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The second input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A pair of tensors of same rank.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">left</span><span class="p">)</span>
|
||
<span class="n">right</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">right</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span>
|
||
<span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="kc">None</span><span class="p">)</span>
|
||
<span class="n">right</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">left</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
|
||
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o"><</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">></span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
|
||
<span class="n">right</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">left</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="elementwise_binary">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.elementwise_binary">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">elementwise_binary</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="nb">float</span><span class="p">],</span> <span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
|
||
<span class="n">op</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an elementwise operation with two inputs.</span>
|
||
|
||
<span class="sd"> For each input, that function first creates a constant tensor if the input</span>
|
||
<span class="sd"> is an integer or a float. Then, if needed, it expands the smaller tensor to</span>
|
||
<span class="sd"> make sure its rank is the same as the larger one. Then, it performs the</span>
|
||
<span class="sd"> elementwise operation 'op'.</span>
|
||
|
||
<span class="sd"> The following closures are defined in functional.*:</span>
|
||
|
||
<span class="sd"> add for op=trt.ElementWiseOperation.SUM</span>
|
||
<span class="sd"> sub for op=trt.ElementWiseOperation.SUB</span>
|
||
<span class="sd"> mul for op=trt.ElementWiseOperation.PROD</span>
|
||
<span class="sd"> div for op=trt.ElementWiseOperation.DIV</span>
|
||
<span class="sd"> floordiv for op=trt.ElementWiseOperation.FLOOR_DIV</span>
|
||
<span class="sd"> gt for op=trt.ElementWiseOperation.GREATER</span>
|
||
<span class="sd"> lt for op=trt.ElementWiseOperation.LESS</span>
|
||
<span class="sd"> op_and for op=trt.ElementWiseOperation.AND</span>
|
||
<span class="sd"> op_or for op=trt.ElementWiseOperation.OR</span>
|
||
<span class="sd"> eq for op=trt.ElementWiseOperation.EQUAL</span>
|
||
<span class="sd"> minimum for op=trt.ElementWiseOperation.MIN</span>
|
||
<span class="sd"> maximum for op=trt.ElementWiseOperation.MAX</span>
|
||
<span class="sd"> pow for op=trt.ElementWiseOperation.POW</span>
|
||
|
||
<span class="sd"> It is implemented using the IElementWiseLayer from TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> left : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The first input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> right : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The second input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> op : trt.ElementWiseOperation</span>
|
||
<span class="sd"> The binary operation to perform.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this elementwise operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">left</span><span class="p">,</span> <span class="n">right</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span> <span class="ow">and</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span> <span class="ow">and</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span>
|
||
<span class="n">right</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_elementwise</span><span class="p">(</span><span class="n">left</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">right</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">op</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="n">add</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">SUM</span><span class="p">)</span>
|
||
<span class="n">sub</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">SUB</span><span class="p">)</span>
|
||
<span class="n">mul</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">PROD</span><span class="p">)</span>
|
||
<span class="n">div</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">DIV</span><span class="p">)</span>
|
||
<span class="n">floordiv</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">FLOOR_DIV</span><span class="p">)</span>
|
||
<span class="n">gt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">GREATER</span><span class="p">)</span>
|
||
<span class="n">lt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">LESS</span><span class="p">)</span>
|
||
<span class="n">op_and</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">AND</span><span class="p">)</span>
|
||
<span class="n">op_or</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">OR</span><span class="p">)</span>
|
||
<span class="n">eq</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">EQUAL</span><span class="p">)</span>
|
||
<span class="n">minimum</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">MIN</span><span class="p">)</span>
|
||
<span class="n">maximum</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">)</span>
|
||
<span class="nb">pow</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">POW</span><span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="modulo">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.modulo">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">modulo</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> This function adds an element-wise modulo (x % y) operation for a given tensor.</span>
|
||
<span class="sd"> Since there is no TensorRT layer that can directly perform this,</span>
|
||
<span class="sd"> this function implements it using some of the basic operations.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor that represents (x % y) modulo operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">x</span> <span class="o">-</span> <span class="p">(</span><span class="n">x</span> <span class="o">//</span> <span class="n">y</span><span class="p">)</span> <span class="o">*</span> <span class="n">y</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="where">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.where">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">where</span><span class="p">(</span><span class="n">condition</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">bool</span><span class="p">],</span> <span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
|
||
<span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a where (aka select or if-then-else) operation.</span>
|
||
|
||
<span class="sd"> Assuming the three input parameters have the same shape, that function creates</span>
|
||
<span class="sd"> the operation to compute a tensor of the same shape such that:</span>
|
||
|
||
<span class="sd"> for ii in range(mul(condition.shape)):</span>
|
||
<span class="sd"> output[ii] = left[ii] if condition[ii] else right[ii]</span>
|
||
|
||
<span class="sd"> For each input, that function first creates a constant tensor if the</span>
|
||
<span class="sd"> condition is boolean or the left/right input is an integer or a float.</span>
|
||
<span class="sd"> Then, if needed, it expands the smaller tensor to make sure its</span>
|
||
<span class="sd"> rank is the same as the larger one. Then, it performs the selection.</span>
|
||
|
||
<span class="sd"> It is implemented using the ISelectLayer from TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> condition : Union[Tensor, bool]</span>
|
||
<span class="sd"> The condition. If that input is a boolean, the function</span>
|
||
<span class="sd"> creates a constant tensor.</span>
|
||
|
||
<span class="sd"> left : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The first input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> right : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The second input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this where operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="c1"># Convert to tensors.</span>
|
||
<span class="n">condition</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">condition</span><span class="p">)</span>
|
||
<span class="n">left</span><span class="p">,</span> <span class="n">right</span> <span class="o">=</span> <span class="n">constants_to_tensors_</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Find the tensor with the largest rank of the three.</span>
|
||
<span class="n">largest</span> <span class="o">=</span> <span class="n">condition</span>
|
||
<span class="k">if</span> <span class="n">largest</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o"><</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
|
||
<span class="n">largest</span> <span class="o">=</span> <span class="n">left</span>
|
||
<span class="k">if</span> <span class="n">largest</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o"><</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
|
||
<span class="n">largest</span> <span class="o">=</span> <span class="n">right</span>
|
||
|
||
<span class="c1"># Expand the tensors to match the largest one.</span>
|
||
<span class="k">if</span> <span class="n">condition</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
|
||
<span class="n">condition</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">left</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">right</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
|
||
<span class="n">right</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Insert the operation.</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_select</span><span class="p">(</span><span class="n">condition</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">left</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">right</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="unary">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.unary">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">unary</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">op</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an elementwise operation on a single input.</span>
|
||
|
||
<span class="sd"> The following closures are defined in functional.*:</span>
|
||
|
||
<span class="sd"> round for op=trt.UnaryOperation.ROUND</span>
|
||
<span class="sd"> sqrt for op=trt.UnaryOperation.SQRT</span>
|
||
<span class="sd"> exp for op=trt.UnaryOperation.EXP</span>
|
||
<span class="sd"> sin for op=trt.UnaryOperation.SIN</span>
|
||
<span class="sd"> cos for op=trt.UnaryOperation.COS</span>
|
||
<span class="sd"> abs for op=trt.UnaryOperation.ABS</span>
|
||
<span class="sd"> log for op=trt.UnaryOperation.LOG</span>
|
||
|
||
<span class="sd"> It is implemented using the IUnaryLayer from TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> op : trt.UnaryOperation</span>
|
||
<span class="sd"> The unary operation to perform.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this elementwise operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_unary</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="nb">round</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">ROUND</span><span class="p">)</span>
|
||
<span class="n">sqrt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">SQRT</span><span class="p">)</span>
|
||
<span class="n">exp</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">EXP</span><span class="p">)</span>
|
||
<span class="n">sin</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">SIN</span><span class="p">)</span>
|
||
<span class="n">cos</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">COS</span><span class="p">)</span>
|
||
<span class="nb">abs</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">ABS</span><span class="p">)</span>
|
||
<span class="n">log</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">LOG</span><span class="p">)</span>
|
||
<span class="n">not_op</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">NOT</span><span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="log_softmax">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.log_softmax">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">log_softmax</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> This function is equivalent of torch.nn.functional.log_softmax() i.e.</span>
|
||
<span class="sd"> it performs log(softmax(input)) in a safer and faster way.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input: Tensor</span>
|
||
<span class="sd"> The data tensor on which log_softmax to be computed.</span>
|
||
<span class="sd"> dim: int</span>
|
||
<span class="sd"> The dimension of the input tensor along which log_softmax will be computed.</span>
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor of same shape as input with log_softmax computed on the specified dim.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">x_max</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">-</span> <span class="n">x_max</span>
|
||
<span class="k">return</span> <span class="n">x</span> <span class="o">-</span> <span class="n">log</span><span class="p">(</span><span class="nb">sum</span><span class="p">(</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="reduce">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.reduce">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">reduce</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">op</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an reduction operation to do along a dimension.</span>
|
||
|
||
<span class="sd"> It is implemented using the IReduceLayer from TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> op : trt.ReduceOperation</span>
|
||
<span class="sd"> The reduction operation to perform.</span>
|
||
<span class="sd"> Options: SUM, PROD, MAX, MIN, AVG</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension along which the reduction is performed.</span>
|
||
|
||
<span class="sd"> keepdim : bool</span>
|
||
<span class="sd"> Is the dimension kept in the reduced tensor? When True the</span>
|
||
<span class="sd"> dimension is kept, it is removed from the shape otherwise.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this reduction operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_reduce</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">op</span><span class="p">,</span>
|
||
<span class="n">axes</span><span class="p">,</span>
|
||
<span class="n">keep_dims</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="n">prod</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">reduce</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">PROD</span><span class="p">)</span>
|
||
<span class="nb">min</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">reduce</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">MIN</span><span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="mean">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.mean">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">mean</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to compute the mean along a dimension.</span>
|
||
|
||
<span class="sd"> Computes the mean along the dimension 'dim' of the input tensor.</span>
|
||
|
||
<span class="sd"> It is implemented using the IReduceLayer from TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension along which the mean is computed.</span>
|
||
|
||
<span class="sd"> keepdim : bool</span>
|
||
<span class="sd"> Is the dimension kept in the reduced tensor? When True the</span>
|
||
<span class="sd"> dimension is kept, it is removed from the shape otherwise.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this reduction operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">AVG</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="max">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.max">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">max</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to compute the max along a dimension.</span>
|
||
|
||
<span class="sd"> Computes the max along the dimension 'dim' of the input tensor.</span>
|
||
|
||
<span class="sd"> It is implemented using the IReduceLayer from TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension along which the mean is computed.</span>
|
||
|
||
<span class="sd"> keepdim : bool</span>
|
||
<span class="sd"> Is the dimension kept in the reduced tensor? When True the</span>
|
||
<span class="sd"> dimension is kept, it is removed from the shape otherwise.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this reduction operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="sum">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.sum">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">sum</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to compute the sum along a dimension.</span>
|
||
|
||
<span class="sd"> Computes the sum along the dimension 'dim' of the input tensor.</span>
|
||
|
||
<span class="sd"> It is implemented using the IReduceLayer from TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension along which the mean is computed.</span>
|
||
|
||
<span class="sd"> keepdim : bool</span>
|
||
<span class="sd"> Is the dimension kept in the reduced tensor? When True the</span>
|
||
<span class="sd"> dimension is kept, it is removed from the shape otherwise.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this reduction operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">SUM</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="identity">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.identity">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">identity</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an identity operation.</span>
|
||
|
||
<span class="sd"> TODO: Document why it can be done using a plugin!!!</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this identity operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">identity_plugin</span><span class="p">:</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_identity</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Identity'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">()</span>
|
||
<span class="n">id_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"identity"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">id_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">"identity"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="argmax">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.argmax">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">argmax</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an argmax operation.</span>
|
||
|
||
<span class="sd"> As explained in the ONNX documentation,</span>
|
||
|
||
<span class="sd"> https://github.com/onnx/onnx/blob/main/docs/Operators.md#argmax</span>
|
||
|
||
<span class="sd"> that function creates a layer computing the indices of the max elements of</span>
|
||
<span class="sd"> the input tensor's element along the provided dim. The resulting tensor</span>
|
||
<span class="sd"> has the same rank as the input if keepdims is True. If keepdims is False,</span>
|
||
<span class="sd"> then the resulting tensor has the reduced dimension pruned.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension in which to compute the argmax indices.</span>
|
||
|
||
<span class="sd"> keepdim : bool</span>
|
||
<span class="sd"> Do we keep the dimension along which the reduction is performed?</span>
|
||
<span class="sd"> Yes, if set to True, no otherwise.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this argmax operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_topk</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">TopKOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">,</span>
|
||
<span class="mi">1</span><span class="p">,</span> <span class="n">axes</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">keepdim</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="n">a</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()))</span>
|
||
<span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">dim</span><span class="p">:</span>
|
||
<span class="n">a</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="n">d</span><span class="p">)</span>
|
||
<span class="n">indices</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="n">a</span><span class="p">))</span>
|
||
<span class="n">output_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="n">output_shape</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">indices</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">new_shape</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="gelu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gelu">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">gelu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a GELU operation.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">x</span> <span class="o">*</span> <span class="p">(</span>
|
||
<span class="n">tanh</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">2.0</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="mf">0.044715</span> <span class="o">*</span> <span class="nb">pow</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mf">3.0</span><span class="p">)))</span> <span class="o">+</span> <span class="mf">1.0</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="geglu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.geglu">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">geglu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a Gated-GELU operation.</span>
|
||
|
||
<span class="sd"> That function takes a tensor, splits it into two halves along the last</span>
|
||
<span class="sd"> dimension, applies GELU to the second half and multiply the results. The</span>
|
||
<span class="sd"> behavior is undefined if the last dimension is not even.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">chunk</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">a</span> <span class="o">*</span> <span class="n">gelu</span><span class="p">(</span><span class="n">b</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="quick_gelu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.quick_gelu">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">quick_gelu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">sigmoid</span><span class="p">(</span><span class="mf">1.702</span> <span class="o">*</span> <span class="n">x</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="gegelu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gegelu">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">gegelu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">limit</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="c1"># a, b = x[..., ::2], x[..., 1::2]</span>
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="n">a_starts</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
<span class="n">b_starts</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
<span class="n">shapes</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)</span>
|
||
<span class="p">])</span>
|
||
<span class="n">strides</span> <span class="o">=</span> <span class="p">[</span><span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
|
||
<span class="n">a</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">a_starts</span><span class="p">,</span> <span class="n">shapes</span><span class="p">,</span> <span class="n">strides</span><span class="p">)</span>
|
||
<span class="n">b</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">b_starts</span><span class="p">,</span> <span class="n">shapes</span><span class="p">,</span> <span class="n">strides</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">limit</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">a</span> <span class="o">=</span> <span class="n">clip</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="nb">float</span><span class="p">(</span><span class="o">-</span><span class="mf">1e20</span><span class="p">),</span> <span class="n">beta</span><span class="o">=</span><span class="n">limit</span><span class="p">)</span>
|
||
<span class="n">b</span> <span class="o">=</span> <span class="n">clip</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=-</span><span class="n">limit</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="n">limit</span><span class="p">)</span>
|
||
|
||
<span class="c1"># C = B + 1</span>
|
||
<span class="n">const1</span> <span class="o">=</span> <span class="n">arange</span><span class="p">(</span><span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="mi">1</span><span class="p">)),</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="mi">2</span><span class="p">)),</span>
|
||
<span class="n">trt_dtype_to_str</span><span class="p">(</span><span class="n">b</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
|
||
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
|
||
<span class="n">const1</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">const1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">b_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)])</span>
|
||
<span class="n">const1_arr</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">const1</span><span class="p">,</span> <span class="n">b_shape</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">quick_gelu</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">b</span> <span class="o">+</span> <span class="n">const1_arr</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="group_norm">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.group_norm">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">group_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">num_groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">):</span>
|
||
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document that function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">num_channels</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="n">old_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">num_groups</span><span class="p">,</span>
|
||
<span class="n">num_channels</span> <span class="o">//</span> <span class="n">num_groups</span><span class="p">,</span>
|
||
<span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">ndim</span><span class="p">)])</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
|
||
|
||
<span class="c1"># instance norm</span>
|
||
<span class="n">w_shape</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_groups</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)]</span>
|
||
<span class="n">instance_weight</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)))</span>
|
||
<span class="n">instance_bias</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)))</span>
|
||
<span class="n">axes_mask</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()):</span>
|
||
<span class="n">axes_mask</span> <span class="o">|=</span> <span class="mi">1</span> <span class="o"><<</span> <span class="n">i</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_normalization</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">instance_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">instance_bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">axes_mask</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">epsilon</span> <span class="o">=</span> <span class="n">eps</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">old_shape</span><span class="p">)</span>
|
||
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">num_channels</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">ndim</span><span class="p">)])</span>
|
||
<span class="k">if</span> <span class="n">weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">*</span> <span class="n">weight</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">bias</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">y</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="softplus">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.softplus">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">softplus</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add the softplus activation base on PyTorch definition.</span>
|
||
|
||
<span class="sd"> See https://pytorch.org/docs/stable/generated/torch.nn.functional.softplus.html for a</span>
|
||
<span class="sd"> description of that function.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> Input TensorRT-LLM Tensor.</span>
|
||
<span class="sd"> beta : float</span>
|
||
<span class="sd"> The parameter for softplus computation.</span>
|
||
<span class="sd"> threshold : float</span>
|
||
<span class="sd"> The threshold for reverting to the linear function when input * beta > threshold</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor created by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">sf_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">SOFTPLUS</span><span class="p">)</span>
|
||
<span class="n">sf_layer</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">beta</span>
|
||
<span class="n">sf_layer</span><span class="o">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">beta</span>
|
||
|
||
<span class="n">prod_tensor</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">*</span> <span class="n">beta</span>
|
||
<span class="n">result</span> <span class="o">=</span> <span class="n">prod_tensor</span> <span class="o">></span> <span class="n">threshold</span>
|
||
|
||
<span class="k">return</span> <span class="n">where</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">sf_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">sf_layer</span><span class="p">))</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="outer">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.outer">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">outer</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">vec2</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to compute the outer product between two tensors.</span>
|
||
|
||
<span class="sd"> That operation creates an Einsum node.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The first input tensor.</span>
|
||
|
||
<span class="sd"> vec2 : Tensor</span>
|
||
<span class="sd"> The second input tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor produced by this layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">einsum</span><span class="p">(</span><span class="s1">'i,j->ij'</span><span class="p">,</span> <span class="p">[</span><span class="nb">input</span><span class="p">,</span> <span class="n">vec2</span><span class="p">])</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="avg_pool2d">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.avg_pool2d">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">avg_pool2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
|
||
<span class="n">stride</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">ceil_mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">count_include_pad</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document that function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_pooling_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PoolingType</span><span class="o">.</span><span class="n">AVERAGE</span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">stride</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">stride</span> <span class="o">=</span> <span class="n">kernel_size</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
|
||
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
|
||
|
||
<span class="k">return</span> <span class="n">output</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="conv1d">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.conv1d">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">conv1d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dilation</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
|
||
<span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="n">kernel_size</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span>
|
||
<span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="n">input_shuffled</span> <span class="o">=</span> <span class="n">stack</span><span class="p">([</span><span class="nb">input</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
<span class="n">kernel_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Dims</span><span class="p">([</span><span class="n">kernel_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_convolution_nd</span><span class="p">(</span><span class="n">input_shuffled</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">noutput</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="p">(</span><span class="n">stride</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="p">(</span><span class="n">padding</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">dilation_nd</span> <span class="o">=</span> <span class="p">(</span><span class="n">dilation</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
|
||
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">output_2d</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="n">output_1d</span> <span class="o">=</span> <span class="n">squeeze</span><span class="p">(</span><span class="n">output_2d</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">output_1d</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="conv2d">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.conv2d">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">conv2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">dilation</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">pre_padding</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">post_padding</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document that function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="n">kernel_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_convolution_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">noutput</span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="n">padding</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">dilation_nd</span> <span class="o">=</span> <span class="n">dilation</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">dilation_nd</span> <span class="o">=</span> <span class="n">dilation</span>
|
||
<span class="k">if</span> <span class="n">pre_padding</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">pre_padding</span> <span class="o">=</span> <span class="n">pre_padding</span>
|
||
<span class="k">if</span> <span class="n">post_padding</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">post_padding</span> <span class="o">=</span> <span class="n">post_padding</span>
|
||
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
|
||
|
||
<span class="k">return</span> <span class="n">output</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="conv_transpose2d">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.conv_transpose2d">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">conv_transpose2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">output_padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">dilation</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document that function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
|
||
<span class="n">kernel_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_deconvolution_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">noutput</span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="n">padding</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
|
||
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
|
||
|
||
<span class="k">return</span> <span class="n">output</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="split">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.split">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">split</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">split_size_or_sections</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that splits a tensor into sub-tensors.</span>
|
||
|
||
<span class="sd"> This operation creates a list of tensors that are obtained from the input</span>
|
||
<span class="sd"> tensor by slicing it along the dimension 'dim'. If 'split_size_or_sections'</span>
|
||
<span class="sd"> is an integer, the tensor is split into 'input.shape[dim] /</span>
|
||
<span class="sd"> split_size_or_sections' slices. If 'split_size_or_sections' is a list of</span>
|
||
<span class="sd"> sizes, the tensor is split into 'len(split_size_or_sections)' slices and</span>
|
||
<span class="sd"> the size of the ith slice is given by 'split_size_or_sections[i]'.</span>
|
||
|
||
<span class="sd"> There are several constraints with the current implementation:</span>
|
||
|
||
<span class="sd"> - The input tensor must be static (no dynamic dimension),</span>
|
||
<span class="sd"> - If 'split_size_or_sections' is an integer, the number of elements in</span>
|
||
<span class="sd"> the 'dim' dimension of the input must be a multiple of</span>
|
||
<span class="sd"> 'split_size_or_sections': 'input.shape[dim] % split_size_or_sections == 0'.</span>
|
||
<span class="sd"> - If 'split_size_or_sections' is a sequence, the sum of the elements in</span>
|
||
<span class="sd"> 'split_size_or_sections' must be equal to the size in the dimension</span>
|
||
<span class="sd"> 'dim': 'input.shape[dim] == sum(ii for ii in split_size_or_sections)'.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a 'slice' operation for each output</span>
|
||
<span class="sd"> slice.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor to slice.</span>
|
||
|
||
<span class="sd"> split_size_or_sections : Union[int, Sequence[int]]</span>
|
||
<span class="sd"> If it is an integer, it encodes the size of each slice. Otherwise,</span>
|
||
<span class="sd"> if it is a sequence, it is the size of each slice.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension of the tensor to slice.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The list of tensors produced by the different operations.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="n">tensor</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">dim</span> <span class="o">+=</span> <span class="n">ndim</span>
|
||
<span class="n">dim_value</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">dim</span><span class="p">]</span>
|
||
<span class="n">starts</span> <span class="o">=</span> <span class="p">[</span><span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="mi">0</span><span class="p">]))</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
<span class="n">sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">split_size_or_sections</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="c1"># TODO: support non-divisible cases</span>
|
||
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">%</span> <span class="n">split_size_or_sections</span> <span class="o">==</span> <span class="mi">0</span>
|
||
<span class="n">num_sections</span> <span class="o">=</span> <span class="n">dim_value</span> <span class="o">//</span> <span class="n">split_size_or_sections</span>
|
||
<span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="n">split_size_or_sections</span><span class="p">]))</span>
|
||
|
||
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_sections</span><span class="p">):</span>
|
||
<span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="n">split_size_or_sections</span> <span class="o">*</span> <span class="n">i</span><span class="p">]))</span>
|
||
<span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
|
||
<span class="k">return</span> <span class="n">outputs</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">total_size</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">split_size_or_sections</span><span class="p">:</span>
|
||
<span class="n">total_size</span> <span class="o">+=</span> <span class="n">i</span>
|
||
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">==</span> <span class="n">total_size</span>
|
||
<span class="n">num_sections</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">split_size_or_sections</span><span class="p">)</span>
|
||
|
||
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_sections</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">+</span> <span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span>
|
||
<span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="n">split_size_or_sections</span><span class="p">[</span><span class="n">i</span><span class="p">]]))</span>
|
||
<span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
|
||
<span class="k">return</span> <span class="n">outputs</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="chunk">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.chunk">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">chunk</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">chunks</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that splits a tensor into sub-tensors.</span>
|
||
|
||
<span class="sd"> This operation creates a list of tensors that are obtained from the input</span>
|
||
<span class="sd"> tensor by chunking it along the dimension 'dim'. It produces 'chunks'</span>
|
||
<span class="sd"> sub-tensors.</span>
|
||
|
||
<span class="sd"> That operation is only defined for static tensors (no dynamic dimension)</span>
|
||
<span class="sd"> and the size of the tensor in the dimension 'dim' must be a multiple of</span>
|
||
<span class="sd"> 'chunks': 'input.shape[dim] % chunks == 0'.</span>
|
||
|
||
<span class="sd"> It maps to 'split' with 'split_size = input.shape[dim] / chunks'.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor to slice.</span>
|
||
|
||
<span class="sd"> chunks : int</span>
|
||
<span class="sd"> The number of slices to split the input tensor into.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension of the tensor to slice.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The list of tensors produced by the different operations.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="n">tensor</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">dim</span> <span class="o">+=</span> <span class="n">ndim</span>
|
||
<span class="n">dim_value</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">dim</span><span class="p">]</span>
|
||
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">%</span> <span class="n">chunks</span> <span class="o">==</span> <span class="mi">0</span>
|
||
|
||
<span class="k">return</span> <span class="n">split</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">dim_value</span> <span class="o">//</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="unbind">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.unbind">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">unbind</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Removes a tensor dimension.</span>
|
||
|
||
<span class="sd"> Returns a tuple of all slices along a given dimension, already without it.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="n">outputs</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
|
||
<span class="n">output_shape</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">]</span>
|
||
<span class="k">return</span> <span class="p">[</span><span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">output_shape</span><span class="p">)</span> <span class="k">for</span> <span class="n">output</span> <span class="ow">in</span> <span class="n">outputs</span><span class="p">]</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="AllReduceStrategy">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceStrategy">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">AllReduceStrategy</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Warning: actual definition is in cpp/tensorrt_llm/kernels/customAllReduceKernels.h</span>
|
||
<span class="sd"> they must be kept in sync</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">NCCL</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">ONESHOT</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">TWOSHOT</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="n">UB</span> <span class="o">=</span> <span class="mi">3</span>
|
||
<span class="n">AUTO</span> <span class="o">=</span> <span class="mi">4</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="AllReduceConfig">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceConfig">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">AllReduceConfig</span><span class="p">(</span><span class="n">IntFlag</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Warning: actual definition is in cpp/tensorrt_llm/kernels/customAllReduceKernels.h</span>
|
||
<span class="sd"> they must be kept in sync</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">USE_MEMCPY</span> <span class="o">=</span> <span class="n">auto</span><span class="p">()</span>
|
||
<span class="n">PUSH_MODE</span> <span class="o">=</span> <span class="n">auto</span><span class="p">()</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="AllReduceFusionOp">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceFusionOp">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">AllReduceFusionOp</span><span class="p">(</span><span class="n">IntFlag</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Warning: actual definition is in cpp/tensorrt_llm/kernels/customAllReduceKernels.h</span>
|
||
<span class="sd"> they must be kept in sync</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">NONE</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">RESIDUAL_RMS_NORM</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">LAST_PROCESS_FOR_UB</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="n">RESIDUAL_RMS_PREPOST_NORM</span> <span class="o">=</span> <span class="mi">3</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="AllReduceParams">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">AllReduceParams</span><span class="p">():</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
||
<span class="n">strategy</span><span class="p">:</span> <span class="n">AllReduceStrategy</span> <span class="o">=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">AUTO</span><span class="p">,</span>
|
||
<span class="n">config</span><span class="p">:</span> <span class="n">AllReduceConfig</span> <span class="o">=</span> <span class="n">AllReduceConfig</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">fusion_op</span><span class="p">:</span> <span class="n">AllReduceFusionOp</span> <span class="o">=</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">NONE</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">residual</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">norm_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">norm_pre_residual_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-06</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">strategy</span> <span class="o">=</span> <span class="n">strategy</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">config</span> <span class="o">=</span> <span class="n">config</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">=</span> <span class="n">fusion_op</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">residual</span> <span class="o">=</span> <span class="n">residual</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">norm_weight</span> <span class="o">=</span> <span class="n">norm_weight</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">scale</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">norm_pre_residual_weight</span> <span class="o">=</span> <span class="n">norm_pre_residual_weight</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
|
||
<span class="k">assert</span> <span class="n">fusion_op</span> <span class="o">==</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">NONE</span> <span class="ow">or</span> <span class="p">(</span><span class="n">residual</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="AllReduceParams.has_affine">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams.has_affine">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">has_affine</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="mi">1</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">0</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="AllReduceParams.has_bias">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams.has_bias">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">has_bias</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="mi">1</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">0</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="AllReduceParams.has_scale">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams.has_scale">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">has_scale</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="mi">1</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">0</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="AllReduceParams.update_strategy">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams.update_strategy">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">update_strategy</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">strategy</span> <span class="o">==</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">AUTO</span> <span class="ow">and</span> <span class="n">default_net</span><span class="p">(</span>
|
||
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">user_buffer</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">strategy</span> <span class="o">=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="create_allreduce_plugin">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.create_allreduce_plugin">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">create_allreduce_plugin</span><span class="p">(</span>
|
||
<span class="n">network</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">INetworkDefinition</span><span class="p">,</span>
|
||
<span class="n">tensor</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ITensor</span><span class="p">,</span>
|
||
<span class="n">workspace</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">ITensor</span><span class="p">],</span>
|
||
<span class="n">group</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">,</span>
|
||
<span class="n">all_reduce_params</span><span class="p">:</span> <span class="n">AllReduceParams</span><span class="p">,</span>
|
||
<span class="p">):</span>
|
||
<span class="n">allreduce_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'AllReduce'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">allreduce_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">pf_group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"group"</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pf_dtype</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="p">[</span><span class="n">pf_group</span><span class="p">,</span> <span class="n">pf_dtype</span><span class="p">]</span>
|
||
<span class="n">p_strategy</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"strategy"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_strategy</span><span class="p">)</span>
|
||
<span class="n">p_config</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"config"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">config</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_config</span><span class="p">)</span>
|
||
<span class="n">p_fusion_op</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"fusion_op"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_fusion_op</span><span class="p">)</span>
|
||
<span class="n">p_eps</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"eps"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">float</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">eps</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_eps</span><span class="p">)</span>
|
||
|
||
<span class="n">p_affine</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"affine"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_affine</span><span class="p">())],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_affine</span><span class="p">)</span>
|
||
<span class="n">p_bias</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"bias"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_bias</span><span class="p">())],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_bias</span><span class="p">)</span>
|
||
<span class="n">p_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"scale"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_scale</span><span class="p">())],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_scale</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span><span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">ar_plug</span> <span class="o">=</span> <span class="n">allreduce_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"allreduce"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">!=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">NCCL</span> <span class="ow">and</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">!=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">workspace</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">!=</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">NONE</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_bias</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">residual</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_affine</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">norm_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">==</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">RESIDUAL_RMS_PREPOST_NORM</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">all_reduce_params</span><span class="o">.</span><span class="n">norm_pre_residual_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_scale</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">network</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">ar_plug</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">layer</span><span class="p">,</span> <span class="n">allreduce_plg_creator</span><span class="p">,</span> <span class="n">pfc</span></div>
|
||
|
||
|
||
|
||
<span class="n">allreduce_ub_counter</span> <span class="o">=</span> <span class="mi">0</span>
|
||
|
||
|
||
<div class="viewcode-block" id="allreduce">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.allreduce">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">allreduce</span><span class="p">(</span>
|
||
<span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
|
||
<span class="n">all_reduce_params</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AllReduceParams</span><span class="p">]</span> <span class="o">=</span> <span class="n">AllReduceParams</span><span class="p">()</span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs a collective all-reduce.</span>
|
||
|
||
<span class="sd"> Let's define 'world_size' as the length of the 'group' list. That functions</span>
|
||
<span class="sd"> creates a layer to compute the sum of 'world_size' tensors distributed</span>
|
||
<span class="sd"> amongst the 'world_size' participating ranks (one GPU per rank).</span>
|
||
|
||
<span class="sd"> The list 'group' contains the identifiers of the ranks participating into</span>
|
||
<span class="sd"> the collective operation.</span>
|
||
|
||
<span class="sd"> The tensors in the different ranks must be 1D tensors (or views) and the output</span>
|
||
<span class="sd"> tensor will have that same shape. The output tensor will be replicated on</span>
|
||
<span class="sd"> the 'world_size' ranks.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL all-reduce</span>
|
||
<span class="sd"> collective operation. See</span>
|
||
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce</span>
|
||
<span class="sd"> for details.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> group : List[int]</span>
|
||
<span class="sd"> The ranks participating into the all-reduce operation.</span>
|
||
|
||
<span class="sd"> strategy: AllReduceStrategy</span>
|
||
<span class="sd"> NCCL delegates all-reduce to NCCL while ONESHOT and TWOSHOT are custom latency-optimal algorithms.</span>
|
||
<span class="sd"> AUTO chooses amongst the three based on a message-size heuristic.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="k">global</span> <span class="n">allreduce_ub_counter</span>
|
||
<span class="n">allreduce_ub_counter</span> <span class="o">+=</span> <span class="mi">1</span>
|
||
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">all_reduce_params</span> <span class="o">=</span> <span class="n">AllReduceParams</span><span class="p">()</span>
|
||
<span class="n">all_reduce_params</span><span class="o">.</span><span class="n">update_strategy</span><span class="p">()</span>
|
||
|
||
<span class="c1"># TODO(TRTLLM-996): remove this WAR when custom allreduce is supported</span>
|
||
<span class="c1"># for encoder models in C++ runtime.</span>
|
||
<span class="n">workspace</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">!=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">NCCL</span> <span class="ow">and</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">!=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">current_all_reduce_helper</span><span class="p">()</span><span class="o">.</span><span class="n">workspace</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">NCCL</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">workspace</span> <span class="o">=</span> <span class="n">current_all_reduce_helper</span><span class="p">()</span><span class="o">.</span><span class="n">workspace</span><span class="o">.</span><span class="n">trt_tensor</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">==</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span><span class="p">:</span>
|
||
<span class="n">tensor</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="s2">"allreduce_ub_0_"</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">allreduce_ub_counter</span><span class="p">))</span>
|
||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
|
||
<span class="n">layer</span><span class="p">,</span> <span class="n">allreduce_plg_creator</span><span class="p">,</span> <span class="n">pfc</span> <span class="o">=</span> <span class="n">create_allreduce_plugin</span><span class="p">(</span>
|
||
<span class="n">network</span><span class="o">=</span><span class="n">default_trtnet</span><span class="p">(),</span>
|
||
<span class="n">tensor</span><span class="o">=</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">workspace</span><span class="o">=</span><span class="n">workspace</span><span class="p">,</span>
|
||
<span class="n">group</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">all_reduce_params</span><span class="o">=</span><span class="n">all_reduce_params</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">allreduce_plg_creator</span><span class="p">,</span> <span class="s2">"allreduce"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">!=</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">NONE</span><span class="p">:</span>
|
||
<span class="n">inter_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">==</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span> <span class="ow">and</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_scale</span><span class="p">(</span>
|
||
<span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="c1"># data type: trt.DataType.FP8</span>
|
||
<span class="n">final_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">final_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">==</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_scale</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">final_output</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="s2">"allreduce_ub_1_"</span> <span class="o">+</span>
|
||
<span class="nb">str</span><span class="p">(</span><span class="n">allreduce_ub_counter</span><span class="p">))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">==</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">LAST_PROCESS_FOR_UB</span>
|
||
<span class="n">inter_output</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="s2">"allreduce_ub_1_"</span> <span class="o">+</span>
|
||
<span class="nb">str</span><span class="p">(</span><span class="n">allreduce_ub_counter</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="n">final_output</span><span class="p">,</span> <span class="n">inter_output</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">final_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">final_output</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="allgather">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.allgather">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">allgather</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">gather_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs a collective all-gather.</span>
|
||
|
||
<span class="sd"> Let's define 'group_size' as the length of the 'group' list. That functions</span>
|
||
<span class="sd"> creates a layer to gather 'group_size' tensors distributed</span>
|
||
<span class="sd"> amongst the 'group_size' participating ranks (one GPU per rank).</span>
|
||
|
||
<span class="sd"> The list 'group' contains the identifiers of the ranks participating into</span>
|
||
<span class="sd"> the collective operation.</span>
|
||
|
||
<span class="sd"> Note that 'group' here can be either TP group or PP group, because allgather communication is not limited to a specific split pattern. Therefore 'group_size' does not need to equal MPI 'world_size'.</span>
|
||
|
||
<span class="sd"> The tensors in the different ranks must be 1D tensors (or views) and the</span>
|
||
<span class="sd"> output tensor will have that same shape.</span>
|
||
|
||
<span class="sd"> Given the 'section_size = input.shape[0] / group_size', each rank</span>
|
||
<span class="sd"> contributes a section of its input tensor that correspond to</span>
|
||
<span class="sd"> 'rank*section_size:(rank+1)*section_size'.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL all-gather</span>
|
||
<span class="sd"> collective operation. See</span>
|
||
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather</span>
|
||
<span class="sd"> for details.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> group : List[int]</span>
|
||
<span class="sd"> The ranks participating into the all-gather operation.</span>
|
||
|
||
<span class="sd"> gather_dim: int = 0</span>
|
||
<span class="sd"> Gather along given dimension. By default 0, i.e. treated as 1D tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">allgather_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'AllGather'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">allgather_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">group_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">group</span><span class="p">)</span>
|
||
<span class="n">group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"group"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">group</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">allgather</span> <span class="o">=</span> <span class="n">allgather_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"allgather"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">allgather</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">allgather_plg_creator</span><span class="p">,</span> <span class="s2">"allgather"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
|
||
<span class="c1"># gather along a given dimension other than dim0</span>
|
||
<span class="k">if</span> <span class="n">gather_dim</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="c1"># also support -1 type of dim representation</span>
|
||
<span class="k">if</span> <span class="n">gather_dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">gather_dim</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="n">gather_dim</span>
|
||
|
||
<span class="c1"># plugin above gathers as 1D flattened tensor</span>
|
||
<span class="c1"># 1. [dim0, ...dimi, ...dimN] -> [group_size * dim0, ...dimi, ...dimN]</span>
|
||
|
||
<span class="c1"># now we need to gather-by-dim via split-concat</span>
|
||
<span class="c1"># 2. [group_size * dim0, ...dimi, ...dimN] -> [dim0, ...group_size * dimi, ...dimN]</span>
|
||
<span class="c1"># 2.1 split</span>
|
||
<span class="n">split_size</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">group_size</span>
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="n">starts</span> <span class="o">=</span> <span class="p">[</span><span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="mi">0</span><span class="p">]))</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
<span class="n">sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">d</span><span class="p">)</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
<span class="n">sizes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">split_size</span>
|
||
<span class="n">sections</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">group_size</span><span class="p">):</span>
|
||
<span class="n">starts</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">split_size</span> <span class="o">*</span> <span class="n">i</span>
|
||
<span class="n">sections</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
|
||
<span class="c1"># 2.2 concat</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">sections</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">gather_dim</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">x</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="reduce_scatter">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.reduce_scatter">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">reduce_scatter</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs a reduce-scatter across the ranks in 'group'.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a plugin ('ReduceScatter') that wraps</span>
|
||
<span class="sd"> the NCCL ReduceScatter collective. See</span>
|
||
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclreducescatter</span>
|
||
<span class="sd"> for details.</span>
|
||
|
||
<span class="sd"> The input tensor is cast to the data type selected by the 'nccl_plugin'</span>
|
||
<span class="sd"> option of the plugin config before it is passed to the plugin, and the</span>
|
||
<span class="sd"> result is cast back to the data type of the input tensor.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> group : List[int]</span>
|
||
<span class="sd"> The list of ranks that participate in that operation.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">plg_creater</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'ReduceScatter'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creater</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"group"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">group</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
|
||
<span class="n">reduce_scatter_plug</span> <span class="o">=</span> <span class="n">plg_creater</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"reduce_scatter"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">reduce_scatter_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creater</span><span class="p">,</span> <span class="s2">"reduce_scatter"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="send">
<!-- Rendered listing of tensorrt_llm.functional.send: looks up the 'Send' plugin (version '1'), builds the "tgt_rank" and "type_id" plugin fields, casts 'tensor' to the nccl_plugin dtype, adds the plugin layer and casts its output back to tensor.dtype. -->
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.send">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">send</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs a send from a rank to another.</span>
|
||
|
||
<span class="sd"> The send operation sends a tensor from one rank to another. If a rank 'i'</span>
|
||
<span class="sd"> sends a tensor to a rank 'j', the rank 'j' must have a corresponding 'recv'</span>
|
||
<span class="sd"> operation from rank 'i'. See 'recv'.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL send</span>
|
||
<span class="sd"> point-to-point operation. See</span>
|
||
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend</span>
|
||
<span class="sd"> for details.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> tgt : int</span>
|
||
<span class="sd"> The rank that receives the tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">send_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Send'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">send_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<!-- Note: the int parameter 'tgt' is rebound on the next line to the PluginField that wraps it. -->
<span class="n">tgt</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"tgt_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tgt</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">tgt</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">send_plug</span> <span class="o">=</span> <span class="n">send_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"send"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">send_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">send_plg_creator</span><span class="p">,</span> <span class="s2">"send"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="recv">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.recv">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">recv</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs a recv to a rank from another.</span>
|
||
|
||
<span class="sd"> The recv operation receives a tensor on a rank from another rank. If a rank 'i'</span>
|
||
<span class="sd"> receives a tensor from a rank 'j', the rank 'j' must have a corresponding 'send'</span>
|
||
<span class="sd"> operation to rank 'i'. See 'send'.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL recv</span>
|
||
<span class="sd"> point-to-point operation. See</span>
|
||
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv</span>
|
||
<span class="sd"> for details.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> src : int</span>
|
||
<span class="sd"> The rank that sends the tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">recv_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Recv'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">recv_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">src</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"src_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">src</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">recv_plug</span> <span class="o">=</span> <span class="n">recv_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"recv"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">recv_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">recv_plg_creator</span><span class="p">,</span> <span class="s2">"recv"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="bert_attention">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.bert_attention">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">bert_attention</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">input_lengths</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">head_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">q_scaling</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
|
||
<span class="n">relative_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">relative_attention_bias</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">max_distance</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">max_input_length</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs the multi-head attention in BERT.</span>
|
||
|
||
<span class="sd"> The multi-head attention (MHA) is the sequence of a batched matmul, a</span>
|
||
<span class="sd"> softmax and a batched matmul as described in</span>
|
||
<span class="sd"> https://arxiv.org/abs/1706.03762. That function adds an operation that</span>
|
||
<span class="sd"> performs those computations using a single GPU kernel.</span>
|
||
|
||
<span class="sd"> The input tensor contains the Q, K and V elements. It is a 2D tensor and</span>
|
||
<span class="sd"> its shape is '[sum_of_tokens, 3*hidden_dim]' where the 'sum_of_tokens' is</span>
|
||
<span class="sd"> the sum of the sequence lengths in the batch.</span>
|
||
|
||
<span class="sd"> In MHA, the output of the Q*K^T product is scaled by a constant value that</span>
|
||
<span class="sd"> is computed as:</span>
|
||
|
||
<span class="sd"> 1.f / (q_scaling * sqrt(head_size)).</span>
|
||
|
||
<span class="sd"> That 'q_scaling' constant is the last argument of that function.</span>
|
||
|
||
<span class="sd"> That layer is implemented using a plugin (see bertAttentionPlugin).</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The QKV input tensor.</span>
|
||
|
||
<span class="sd"> input_lengths : Tensor</span>
|
||
<span class="sd"> The length of each sequence. It is a 1D tensor of size 'batch_size'.</span>
|
||
|
||
<span class="sd"> num_heads : int</span>
|
||
<span class="sd"> The number of heads.</span>
|
||
|
||
<span class="sd"> head_size : int</span>
|
||
<span class="sd"> The size of each head.</span>
|
||
|
||
<span class="sd"> q_scaling : float</span>
|
||
<span class="sd"> The factor to compute the scaling factor to scale the output of the</span>
|
||
<span class="sd"> 'Q*K^T' product.</span>
|
||
|
||
<span class="sd"> relative_attention: bool = False</span>
|
||
<span class="sd"> If enable relative attention.</span>
|
||
|
||
<span class="sd"> relative_attention_bias: Tensor = None</span>
|
||
<span class="sd"> The relative attention bias [num_heads, max_seq_len, max_seq_len], or The relative attention embedding table for implicit mode, [num_heads, num_buckets].</span>
|
||
|
||
<span class="sd"> max_distance: int = 0</span>
|
||
<span class="sd"> The maximum distance of relative position in attention, for implicit mode.</span>
|
||
<span class="sd"> Default value is 0, meaning to use the regular mode of relative attention bias.</span>
|
||
<span class="sd"> Implicit mode is only enabled when passing in non-zero positive max_distance value.</span>
|
||
<span class="sd"> See relative attention bias in docs/source/advanced/gpt-attention.md</span>
|
||
|
||
<span class="sd"> max_input_length: Tensor = None</span>
|
||
<span class="sd"> The maximum input sequence length represented by Tensor shape. Requires for remove_input_padding to pre-define plugin workspace size.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">attn_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'BertAttention'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">attn_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">nheads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"num_heads"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">head_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"head_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">head_size</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">q_scaling</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"q_scaling"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">q_scaling</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">context_fmha_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"context_fmha_type"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha_type</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">bert_attention_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">do_relative_attention</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"do_relative_attention"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">relative_attention</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">max_distance</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"max_distance"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_distance</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">remove_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"remove_padding"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
|
||
<span class="n">nheads</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span> <span class="n">q_scaling</span><span class="p">,</span> <span class="n">context_fmha_type</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span>
|
||
<span class="n">do_relative_attention</span><span class="p">,</span> <span class="n">max_distance</span><span class="p">,</span> <span class="n">remove_padding</span>
|
||
<span class="p">])</span>
|
||
|
||
<span class="n">attn_plug</span> <span class="o">=</span> <span class="n">attn_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"padding_attn"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="p">,</span> <span class="n">input_lengths</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">max_input_length</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="c1"># for remove padding mode</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">max_input_length</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">relative_attention_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="c1"># for relative attention mode</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">relative_attention_bias</span><span class="p">]</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">attn_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">attn_plg_creator</span><span class="p">,</span> <span class="s2">"padding_attn"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> \
|
||
<span class="sa">f</span><span class="s2">"Plugin outputs number mismatch with expected, got </span><span class="si">{</span><span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span><span class="si">}</span><span class="s2">, expected 1"</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">return</span> <span class="n">output</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="RopeEmbeddingUtils">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">RopeEmbeddingUtils</span><span class="p">:</span>
|
||
|
||
<div class="viewcode-block" id="RopeEmbeddingUtils.apply_llama3_scaling">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.apply_llama3_scaling">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="c1"># ref: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L298</span>
|
||
<span class="k">def</span> <span class="nf">apply_llama3_scaling</span><span class="p">(</span><span class="n">inv_freqs</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">rope_scaling_config</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
|
||
<span class="n">scale_factor</span> <span class="o">=</span> <span class="n">rope_scaling_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"factor"</span><span class="p">,</span> <span class="mf">8.0</span><span class="p">)</span>
|
||
<span class="n">low_freq_factor</span> <span class="o">=</span> <span class="n">rope_scaling_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"low_freq_factor"</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span>
|
||
<span class="n">high_freq_factor</span> <span class="o">=</span> <span class="n">rope_scaling_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"high_freq_factor"</span><span class="p">,</span> <span class="mf">4.0</span><span class="p">)</span>
|
||
<span class="n">old_context_len</span> <span class="o">=</span> <span class="n">rope_scaling_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
|
||
<span class="s2">"original_max_position_embeddings"</span><span class="p">,</span> <span class="mi">8192</span><span class="p">)</span>
|
||
|
||
<span class="n">low_freq_wavelen</span> <span class="o">=</span> <span class="n">old_context_len</span> <span class="o">/</span> <span class="n">low_freq_factor</span>
|
||
<span class="n">high_freq_wavelen</span> <span class="o">=</span> <span class="n">old_context_len</span> <span class="o">/</span> <span class="n">high_freq_factor</span>
|
||
<span class="n">new_inv_freqs</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">inv_freq</span> <span class="ow">in</span> <span class="n">inv_freqs</span><span class="p">:</span>
|
||
<span class="n">wavelen</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span> <span class="o">/</span> <span class="n">inv_freq</span>
|
||
<span class="k">if</span> <span class="n">wavelen</span> <span class="o"><</span> <span class="n">high_freq_wavelen</span><span class="p">:</span>
|
||
<span class="n">new_inv_freqs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">inv_freq</span><span class="p">)</span>
|
||
<span class="k">elif</span> <span class="n">wavelen</span> <span class="o">></span> <span class="n">low_freq_wavelen</span><span class="p">:</span>
|
||
<span class="n">new_inv_freqs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">inv_freq</span> <span class="o">/</span> <span class="n">scale_factor</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">low_freq_wavelen</span> <span class="o">!=</span> <span class="n">high_freq_wavelen</span>
|
||
<span class="n">smooth</span> <span class="o">=</span> <span class="p">(</span><span class="n">old_context_len</span> <span class="o">/</span> <span class="n">wavelen</span> <span class="o">-</span> <span class="n">low_freq_factor</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span>
|
||
<span class="n">high_freq_factor</span> <span class="o">-</span> <span class="n">low_freq_factor</span><span class="p">)</span>
|
||
<span class="n">new_inv_freqs</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="mi">1</span> <span class="o">-</span> <span class="n">smooth</span><span class="p">)</span> <span class="o">*</span> <span class="n">inv_freq</span> <span class="o">/</span> <span class="n">scale_factor</span> <span class="o">+</span>
|
||
<span class="n">smooth</span> <span class="o">*</span> <span class="n">inv_freq</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">new_inv_freqs</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">inv_freqs</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">create_sinusoidal_positions</span><span class="p">(</span><span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">theta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
|
||
<span class="n">inv_freq</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="n">theta</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"i , j -> i j"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_pos</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">inv_freq</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span>
|
||
<span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">concat</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">create_sinusoidal_positions_for_attention_plugin</span><span class="p">(</span>
|
||
<span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">theta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
|
||
<span class="n">scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">scale_type</span><span class="p">:</span> <span class="n">RotaryScalingType</span> <span class="o">=</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">none</span><span class="p">,</span>
|
||
<span class="c1"># Other scaling configs that only used by certain scaling types.</span>
|
||
<span class="n">rope_scaling_config</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">scale_type</span> <span class="o">==</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">linear</span><span class="p">:</span>
|
||
<span class="n">scale</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">scale</span>
|
||
<span class="k">if</span> <span class="n">scale_type</span> <span class="o">==</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">llama3</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">rope_scaling_config</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"rotary_scaling config must be provided."</span>
|
||
<span class="n">inv_freq</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="n">theta</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">inv_freq</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_llama3_scaling</span><span class="p">(</span>
|
||
<span class="n">inv_freq</span><span class="p">,</span> <span class="n">rope_scaling_config</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">inv_freq</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">/</span> <span class="p">(</span><span class="n">theta</span>
|
||
<span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"i , j -> i j"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_pos</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">inv_freq</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="c1"># fuse cos/sin into float2 (cos, sin).</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
|
||
<span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span>
|
||
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1">#np.cos(sinusoid_inp).shape = (32768, 64, 1)</span>
|
||
|
||
<span class="k">return</span> <span class="n">inv_freq</span><span class="p">,</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions_for_cogvlm_attention_plugin">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions_for_cogvlm_attention_plugin">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">create_sinusoidal_positions_for_cogvlm_attention_plugin</span><span class="p">(</span>
|
||
<span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">theta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
|
||
<span class="n">scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">scale_type</span><span class="p">:</span> <span class="n">RotaryScalingType</span> <span class="o">=</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">none</span><span class="p">,</span>
|
||
<span class="n">vision_start</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">vision_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1225</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">scale_type</span> <span class="o">==</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">linear</span><span class="p">:</span>
|
||
<span class="n">scale</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">scale</span>
|
||
<span class="n">inv_freq</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">/</span> <span class="p">(</span><span class="n">theta</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">position_id</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">([</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">vision_start</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">full</span><span class="p">(</span><span class="n">vision_length</span><span class="p">,</span> <span class="n">vision_start</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">vision_start</span> <span class="o">+</span> <span class="mi">2</span><span class="p">,</span>
|
||
<span class="n">num_pos</span> <span class="o">-</span> <span class="p">(</span><span class="n">vision_length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="p">])</span>
|
||
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"i , j -> i j"</span><span class="p">,</span>
|
||
<span class="n">position_id</span><span class="p">,</span>
|
||
<span class="n">inv_freq</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="c1"># fuse cos/sin into float2 (cos, sin).</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span>
|
||
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">inv_freq</span><span class="p">,</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions_long_rope">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions_long_rope">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">create_sinusoidal_positions_long_rope</span><span class="p">(</span>
|
||
<span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">num_orig_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">theta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
|
||
<span class="n">scaling_short_factors</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">scaling_long_factors</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">short_mscale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">long_mscale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
|
||
|
||
<span class="k">def</span> <span class="nf">_calc_mscale</span><span class="p">(</span><span class="n">scale</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">scale</span> <span class="o"><=</span> <span class="mf">1.0</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="mf">1.0</span>
|
||
<span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">scale</span><span class="p">)</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">num_orig_pos</span><span class="p">))</span>
|
||
|
||
<span class="k">if</span> <span class="n">short_mscale</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">short_mscale</span> <span class="o">=</span> <span class="n">_calc_mscale</span><span class="p">(</span><span class="n">num_pos</span> <span class="o">/</span> <span class="n">num_orig_pos</span><span class="p">)</span>
|
||
<span class="n">long_mscale</span> <span class="o">=</span> <span class="n">short_mscale</span>
|
||
|
||
<span class="k">def</span> <span class="nf">_compute_sinusoidal_positions</span><span class="p">(</span><span class="n">scale_factors</span><span class="p">,</span> <span class="n">is_short</span><span class="p">,</span>
|
||
<span class="n">for_attention_plugin</span><span class="p">):</span>
|
||
<span class="n">inv_freq</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="n">scale_factors</span> <span class="o">*</span>
|
||
<span class="p">(</span><span class="n">theta</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">))</span>
|
||
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"i , j -> i j"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_pos</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">inv_freq</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">for_attention_plugin</span><span class="p">:</span>
|
||
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
|
||
<span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
|
||
<span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">concat</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">mscale</span> <span class="o">=</span> <span class="n">short_mscale</span> <span class="k">if</span> <span class="n">is_short</span> <span class="k">else</span> <span class="n">long_mscale</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">concat</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span> <span class="o">*</span> <span class="n">mscale</span>
|
||
|
||
<span class="c1"># gpt attention plugins also need inv_freq.</span>
|
||
<span class="k">if</span> <span class="n">for_attention_plugin</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">inv_freq</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">concat</span>
|
||
|
||
<span class="k">return</span> <span class="n">_compute_sinusoidal_positions</span><span class="p">(</span>
|
||
<span class="n">scaling_short_factors</span><span class="p">,</span> <span class="kc">True</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span> <span class="n">_compute_sinusoidal_positions</span><span class="p">(</span>
|
||
<span class="n">scaling_long_factors</span><span class="p">,</span>
|
||
<span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span> <span class="n">_compute_sinusoidal_positions</span><span class="p">(</span>
|
||
<span class="n">scaling_short_factors</span><span class="p">,</span> <span class="kc">True</span><span class="p">,</span>
|
||
<span class="kc">True</span><span class="p">),</span> <span class="n">_compute_sinusoidal_positions</span><span class="p">(</span>
|
||
<span class="n">scaling_long_factors</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span> <span class="n">short_mscale</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="RopeEmbeddingUtils.create_fake_weight">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_fake_weight">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">create_fake_weight</span><span class="p">(</span><span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">half</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions_for_deepseek_attention_plugin">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions_for_deepseek_attention_plugin">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">create_sinusoidal_positions_for_deepseek_attention_plugin</span><span class="p">(</span>
|
||
<span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">base</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10000</span><span class="p">,</span>
|
||
<span class="n">scaling_factor</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">original_max_position_embeddings</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4096</span><span class="p">,</span>
|
||
<span class="n">beta_fast</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
|
||
<span class="n">beta_slow</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">mscale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">mscale_all_dim</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
|
||
|
||
<span class="c1"># Copy from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py</span>
|
||
<span class="c1"># Inverse dim formula to find dim based on number of rotations</span>
|
||
<span class="k">def</span> <span class="nf">yarn_find_correction_dim</span><span class="p">(</span><span class="n">num_rotations</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">,</span>
|
||
<span class="n">base</span><span class="o">=</span><span class="mi">10000</span><span class="p">,</span>
|
||
<span class="n">max_position_embeddings</span><span class="o">=</span><span class="mi">2048</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="p">(</span><span class="n">dim</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">max_position_embeddings</span> <span class="o">/</span>
|
||
<span class="p">(</span><span class="n">num_rotations</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span><span class="p">)))</span> <span class="o">/</span> <span class="p">(</span>
|
||
<span class="mi">2</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">base</span><span class="p">))</span>
|
||
|
||
<span class="c1"># Find dim range bounds based on rotations</span>
|
||
<span class="k">def</span> <span class="nf">yarn_find_correction_range</span><span class="p">(</span><span class="n">low_rot</span><span class="p">,</span>
|
||
<span class="n">high_rot</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">,</span>
|
||
<span class="n">base</span><span class="o">=</span><span class="mi">10000</span><span class="p">,</span>
|
||
<span class="n">max_position_embeddings</span><span class="o">=</span><span class="mi">2048</span><span class="p">):</span>
|
||
<span class="n">low</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span>
|
||
<span class="n">yarn_find_correction_dim</span><span class="p">(</span><span class="n">low_rot</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">base</span><span class="p">,</span>
|
||
<span class="n">max_position_embeddings</span><span class="p">))</span>
|
||
<span class="n">high</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span>
|
||
<span class="n">yarn_find_correction_dim</span><span class="p">(</span><span class="n">high_rot</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">base</span><span class="p">,</span>
|
||
<span class="n">max_position_embeddings</span><span class="p">))</span>
|
||
<span class="k">if</span> <span class="n">low</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">low</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="k">if</span> <span class="n">high</span> <span class="o">></span> <span class="n">dim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">high</span> <span class="o">=</span> <span class="n">dim</span> <span class="o">-</span> <span class="mi">1</span>
|
||
<span class="k">return</span> <span class="n">low</span><span class="p">,</span> <span class="n">high</span> <span class="c1"># Clamp values just in case</span>
|
||
|
||
<span class="k">def</span> <span class="nf">yarn_get_mscale</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">mscale</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">scale</span> <span class="o"><=</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="mf">1.0</span>
|
||
<span class="k">return</span> <span class="mf">0.1</span> <span class="o">*</span> <span class="n">mscale</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">scale</span><span class="p">)</span> <span class="o">+</span> <span class="mf">1.0</span>
|
||
|
||
<span class="k">def</span> <span class="nf">yarn_linear_ramp_mask</span><span class="p">(</span><span class="nb">min</span><span class="p">,</span> <span class="nb">max</span><span class="p">,</span> <span class="n">dim</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="nb">min</span> <span class="o">==</span> <span class="nb">max</span><span class="p">:</span>
|
||
<span class="nb">max</span> <span class="o">+=</span> <span class="mf">0.001</span> <span class="c1"># Prevent singularity</span>
|
||
|
||
<span class="n">linear_func</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">-</span> <span class="nb">min</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="nb">max</span> <span class="o">-</span> <span class="nb">min</span><span class="p">)</span>
|
||
<span class="n">ramp_func</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">linear_func</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">ramp_func</span>
|
||
|
||
<span class="n">freq_extra</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="n">base</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span>
|
||
<span class="n">freq_inter</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="n">scaling_factor</span> <span class="o">*</span>
|
||
<span class="n">base</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span>
|
||
|
||
<span class="n">low</span><span class="p">,</span> <span class="n">high</span> <span class="o">=</span> <span class="n">yarn_find_correction_range</span><span class="p">(</span>
|
||
<span class="n">beta_fast</span><span class="p">,</span>
|
||
<span class="n">beta_slow</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">,</span>
|
||
<span class="n">base</span><span class="p">,</span>
|
||
<span class="n">original_max_position_embeddings</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
<span class="n">inv_freq_mask</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">yarn_linear_ramp_mask</span><span class="p">(</span><span class="n">low</span><span class="p">,</span> <span class="n">high</span><span class="p">,</span>
|
||
<span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">inv_freq</span> <span class="o">=</span> <span class="n">freq_inter</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">inv_freq_mask</span><span class="p">)</span> <span class="o">+</span> <span class="n">freq_extra</span> <span class="o">*</span> <span class="n">inv_freq_mask</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_pos</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
|
||
|
||
<span class="n">freqs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">outer</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="n">inv_freq</span><span class="p">)</span>
|
||
|
||
<span class="n">_mscale</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span>
|
||
<span class="n">yarn_get_mscale</span><span class="p">(</span><span class="n">scaling_factor</span><span class="p">,</span> <span class="n">mscale</span><span class="p">)</span> <span class="o">/</span>
|
||
<span class="n">yarn_get_mscale</span><span class="p">(</span><span class="n">scaling_factor</span><span class="p">,</span> <span class="n">mscale_all_dim</span><span class="p">))</span>
|
||
|
||
<span class="n">emb</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">freqs</span><span class="p">,</span> <span class="n">freqs</span><span class="p">),</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">emb</span><span class="p">)</span> <span class="o">*</span> <span class="n">_mscale</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">emb</span><span class="p">)</span> <span class="o">*</span> <span class="n">_mscale</span><span class="p">),</span>
|
||
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">num_pos</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="p">))</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">concat</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
||
|
||
<span class="k">return</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="RopeEmbeddingUtils.rotate_every_two">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.rotate_every_two">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">rotate_every_two</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">4</span>
|
||
|
||
<span class="n">shape_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span>
|
||
<span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
<span class="p">])</span>
|
||
<span class="n">x1</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">shape_tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
|
||
<span class="n">x2</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">shape_tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
|
||
<span class="n">x1</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="n">x2</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">x2</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="n">zero</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">1</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">))))</span>
|
||
<span class="n">x2</span> <span class="o">=</span> <span class="n">zero</span> <span class="o">-</span> <span class="n">x2</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">x2</span><span class="p">,</span> <span class="n">x1</span><span class="p">],</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">view</span><span class="p">(</span>
|
||
<span class="n">x</span><span class="p">,</span> <span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span><span class="p">]))</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="RopeEmbeddingUtils.rotate_half">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.rotate_half">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">rotate_half</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="c1"># [bs, num_attention_kv_heads, seqlen, attention_head_size]</span>
|
||
<span class="k">assert</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">4</span>
|
||
<span class="n">shape_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span>
|
||
<span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
<span class="p">])</span>
|
||
<span class="n">last_dim</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
|
||
<span class="n">x1</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">shape_tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
|
||
<span class="n">x2</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">last_dim</span><span class="p">]),</span> <span class="n">shape_tensor</span><span class="p">,</span>
|
||
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
|
||
<span class="n">zero</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">1</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">))))</span>
|
||
<span class="n">x2</span> <span class="o">=</span> <span class="n">zero</span> <span class="o">-</span> <span class="n">x2</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">x2</span><span class="p">,</span> <span class="n">x1</span><span class="p">],</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">x</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="RopeEmbeddingUtils.apply_rotary_pos_emb">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.apply_rotary_pos_emb">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">apply_rotary_pos_emb</span><span class="p">(</span>
|
||
<span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">position_embedding</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">pos_emb_type</span><span class="p">:</span> <span class="n">PositionEmbeddingType</span> <span class="o">=</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gptj</span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
|
||
<span class="n">rotate_func</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="k">if</span> <span class="n">pos_emb_type</span> <span class="o">==</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span> <span class="ow">or</span> <span class="n">pos_emb_type</span> <span class="o">==</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">long_rope</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span>
|
||
<span class="n">cos</span><span class="p">,</span> <span class="n">sin</span> <span class="o">=</span> <span class="n">position_embedding</span>
|
||
<span class="n">sin</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">sin</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||
<span class="n">cos</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">cos</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||
<span class="n">sin</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">sin</span><span class="p">,</span> <span class="n">sin</span><span class="p">],</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">cos</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">cos</span><span class="p">,</span> <span class="n">cos</span><span class="p">],</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">rotate_func</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">rotate_half</span>
|
||
<span class="k">elif</span> <span class="n">pos_emb_type</span> <span class="o">==</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gptj</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span>
|
||
<span class="n">cos</span><span class="p">,</span> <span class="n">sin</span> <span class="o">=</span> <span class="n">position_embedding</span>
|
||
<span class="n">sin</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">sin</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||
<span class="n">cos</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">cos</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||
<span class="n">sin</span> <span class="o">=</span> <span class="n">repeat_interleave</span><span class="p">(</span><span class="n">sin</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">cos</span> <span class="o">=</span> <span class="n">repeat_interleave</span><span class="p">(</span><span class="n">cos</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">rotate_func</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">rotate_every_two</span>
|
||
<span class="k">elif</span> <span class="n">pos_emb_type</span> <span class="o">==</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">chatglm</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">)</span> <span class="o">==</span> <span class="mi">4</span>
|
||
<span class="n">cos0</span><span class="p">,</span> <span class="n">cos1</span><span class="p">,</span> <span class="n">sin0</span><span class="p">,</span> <span class="n">sin1</span> <span class="o">=</span> <span class="n">position_embedding</span>
|
||
<span class="n">shape_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span>
|
||
<span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
<span class="p">])</span>
|
||
<span class="n">last_dim</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
|
||
<span class="n">x_part0</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">shape_tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
|
||
<span class="n">x_part1</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">last_dim</span><span class="p">]),</span> <span class="n">shape_tensor</span><span class="p">,</span>
|
||
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="n">y_part0</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_part0</span> <span class="o">*</span>
|
||
<span class="n">cos0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">rotate_half</span><span class="p">(</span><span class="n">x_part0</span><span class="p">)</span> <span class="o">*</span> <span class="n">sin0</span><span class="p">)</span>
|
||
<span class="n">y_part1</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_part1</span> <span class="o">*</span>
|
||
<span class="n">cos1</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">rotate_half</span><span class="p">(</span><span class="n">x_part1</span><span class="p">)</span> <span class="o">*</span> <span class="n">sin1</span><span class="p">)</span>
|
||
|
||
<span class="n">result</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">y_part0</span><span class="p">,</span> <span class="n">y_part1</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">result</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">))</span>
|
||
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">'The PositionEmbeddingType is not RoPE'</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="p">(</span><span class="n">tensor</span> <span class="o">*</span> <span class="n">cos</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">rotate_func</span><span class="p">(</span><span class="n">tensor</span><span class="p">)</span> <span class="o">*</span> <span class="n">sin</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="RopeEmbeddingUtils.apply_rotary_pos_emb_chatglm">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.apply_rotary_pos_emb_chatglm">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">apply_rotary_pos_emb_chatglm</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="n">position_embedding</span><span class="p">,</span>
|
||
<span class="n">num_attention_heads</span><span class="p">,</span> <span class="n">attention_head_size</span><span class="p">,</span>
|
||
<span class="n">max_position_embeddings</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_scale</span><span class="p">,</span>
|
||
<span class="n">remove_input_padding</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
|
||
<span class="n">half_head_size</span> <span class="o">=</span> <span class="n">attention_head_size</span> <span class="o">//</span> <span class="mi">2</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="k">else</span> <span class="n">qkv</span>
|
||
<span class="n">input_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
|
||
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">remove_input_padding</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">seqlen</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span> <span class="k">if</span> <span class="n">remove_input_padding</span> <span class="k">else</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
|
||
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="n">qkv</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">batch_size</span><span class="p">,</span>
|
||
<span class="n">seqlen</span><span class="p">,</span>
|
||
<span class="n">num_attention_heads</span><span class="p">,</span>
|
||
<span class="mi">3</span><span class="p">,</span>
|
||
<span class="n">attention_head_size</span><span class="p">,</span>
|
||
<span class="p">]))</span>
|
||
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">q_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">batch_size</span><span class="p">,</span>
|
||
<span class="n">seqlen</span><span class="p">,</span>
|
||
<span class="n">num_attention_heads</span><span class="p">,</span>
|
||
<span class="n">attention_head_size</span><span class="p">,</span>
|
||
<span class="p">])</span>
|
||
<span class="n">query</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
|
||
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
|
||
<span class="n">value</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
|
||
|
||
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">create_sinusoidal_positions</span><span class="p">(</span>
|
||
<span class="n">max_position_embeddings</span><span class="p">,</span> <span class="n">half_head_size</span><span class="p">)</span>
|
||
<span class="n">embedding_weight</span> <span class="o">/=</span> <span class="n">rotary_embedding_scale</span>
|
||
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">embedding_weight</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="mi">2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
|
||
<span class="p">[</span>
|
||
<span class="n">embedding_weight</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
|
||
<span class="n">embedding_weight</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
|
||
<span class="n">embedding_weight</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span>
|
||
<span class="n">embedding_weight</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span>
|
||
<span class="p">],</span>
|
||
<span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">position_embedding</span> <span class="o">=</span> <span class="n">unsqueeze</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">embedding_weight</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">query</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
|
||
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">embedding_weight</span><span class="p">)</span>
|
||
<span class="n">position_embedding</span> <span class="o">=</span> <span class="n">embedding</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="n">embedding_weight</span><span class="p">)</span>
|
||
<span class="n">position_embedding</span><span class="p">,</span> <span class="n">block_embedding</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span>
|
||
<span class="n">position_embedding</span><span class="p">,</span>
|
||
<span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
<span class="n">sin0</span><span class="p">,</span> <span class="n">cos0</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="n">half_head_size</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">sin1</span><span class="p">,</span> <span class="n">cos1</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">block_embedding</span><span class="p">,</span> <span class="n">half_head_size</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
|
||
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">batch_size</span><span class="p">,</span>
|
||
<span class="n">seqlen</span><span class="p">,</span>
|
||
<span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">half_head_size</span><span class="p">,</span>
|
||
<span class="p">])</span>
|
||
<span class="n">position_embedding</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">tensor</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span> <span class="k">for</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="p">[</span><span class="n">cos0</span><span class="p">,</span> <span class="n">cos1</span><span class="p">,</span> <span class="n">sin0</span><span class="p">,</span> <span class="n">sin1</span><span class="p">]</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="n">query</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_rotary_pos_emb</span><span class="p">(</span>
|
||
<span class="n">tensor</span><span class="o">=</span><span class="n">query</span><span class="p">,</span>
|
||
<span class="n">position_embedding</span><span class="o">=</span><span class="n">position_embedding</span><span class="p">,</span>
|
||
<span class="n">pos_emb_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">chatglm</span><span class="p">)</span>
|
||
<span class="n">key</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_rotary_pos_emb</span><span class="p">(</span>
|
||
<span class="n">tensor</span><span class="o">=</span><span class="n">key</span><span class="p">,</span>
|
||
<span class="n">position_embedding</span><span class="o">=</span><span class="n">position_embedding</span><span class="p">,</span>
|
||
<span class="n">pos_emb_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">chatglm</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
|
||
<span class="n">qkv</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">query</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
|
||
<span class="n">key</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
|
||
<span class="n">value</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
|
||
<span class="p">]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">qkv</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
|
||
<span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">qkv</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="RopeEmbeddingUtils.apply_rotary_pos_emb_cogvlm">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.apply_rotary_pos_emb_cogvlm">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">apply_rotary_pos_emb_cogvlm</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="n">position_embedding</span><span class="p">,</span>
|
||
<span class="n">num_attention_heads</span><span class="p">,</span> <span class="n">attention_head_size</span><span class="p">,</span>
|
||
<span class="n">max_position_embeddings</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_scale</span><span class="p">,</span>
|
||
<span class="n">remove_input_padding</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="k">else</span> <span class="n">qkv</span>
|
||
<span class="n">input_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
|
||
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">remove_input_padding</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">seqlen</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span> <span class="k">if</span> <span class="n">remove_input_padding</span> <span class="k">else</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
|
||
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="n">qkv</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">batch_size</span><span class="p">,</span>
|
||
<span class="n">seqlen</span><span class="p">,</span>
|
||
<span class="mi">3</span><span class="p">,</span>
|
||
<span class="n">num_attention_heads</span><span class="p">,</span>
|
||
<span class="n">attention_head_size</span><span class="p">,</span>
|
||
<span class="p">]))</span>
|
||
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
|
||
<span class="n">q_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">batch_size</span><span class="p">,</span>
|
||
<span class="n">seqlen</span><span class="p">,</span>
|
||
<span class="n">num_attention_heads</span><span class="p">,</span>
|
||
<span class="n">attention_head_size</span><span class="p">,</span>
|
||
<span class="p">])</span>
|
||
<span class="n">query</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
|
||
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
|
||
<span class="n">value</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
|
||
|
||
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">create_sinusoidal_positions</span><span class="p">(</span>
|
||
<span class="n">max_position_embeddings</span><span class="p">,</span> <span class="n">attention_head_size</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">embedding_weight</span> <span class="o">/=</span> <span class="n">rotary_embedding_scale</span> <span class="c1"># [max_position_embeddings, attention_head_size]</span>
|
||
|
||
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">position_embedding</span> <span class="o">=</span> <span class="n">unsqueeze</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="c1"># [1, seqlen]</span>
|
||
|
||
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">embedding_weight</span><span class="p">)</span> <span class="c1"># float32</span>
|
||
<span class="n">position_embedding</span> <span class="o">=</span> <span class="n">embedding</span><span class="p">(</span>
|
||
<span class="n">position_embedding</span><span class="p">,</span>
|
||
<span class="n">embedding_weight</span><span class="p">)</span> <span class="c1"># [1, seqlen, attention_head_size]</span>
|
||
<span class="n">sin</span><span class="p">,</span> <span class="n">cos</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="n">attention_head_size</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># [1, seqlen, attention_head_size//2]</span>
|
||
|
||
<span class="n">input_dtype</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">dtype</span>
|
||
<span class="n">fp32_query</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">)</span>
|
||
<span class="n">fp32_key</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">)</span>
|
||
<span class="n">fp32_query</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_rotary_pos_emb</span><span class="p">(</span>
|
||
<span class="n">tensor</span><span class="o">=</span><span class="n">fp32_query</span><span class="p">,</span>
|
||
<span class="n">position_embedding</span><span class="o">=</span><span class="p">[</span><span class="n">cos</span><span class="p">,</span> <span class="n">sin</span><span class="p">],</span>
|
||
<span class="n">pos_emb_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">)</span>
|
||
<span class="n">fp32_key</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_rotary_pos_emb</span><span class="p">(</span>
|
||
<span class="n">tensor</span><span class="o">=</span><span class="n">fp32_key</span><span class="p">,</span>
|
||
<span class="n">position_embedding</span><span class="o">=</span><span class="p">[</span><span class="n">cos</span><span class="p">,</span> <span class="n">sin</span><span class="p">],</span>
|
||
<span class="n">pos_emb_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">)</span>
|
||
|
||
<span class="n">query</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">fp32_query</span><span class="p">,</span> <span class="n">input_dtype</span><span class="p">)</span>
|
||
<span class="n">key</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">fp32_key</span><span class="p">,</span> <span class="n">input_dtype</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
|
||
<span class="n">qkv</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">query</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
|
||
<span class="n">key</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
|
||
<span class="n">value</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
|
||
<span class="p">]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">qkv</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
|
||
<span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">qkv</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="gpt_attention">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gpt_attention">[docs]</a>
|
||
<span class="nd">@gw</span><span class="o">.</span><span class="n">record_signature</span>
|
||
<span class="k">def</span> <span class="nf">gpt_attention</span><span class="p">(</span>
|
||
<span class="o">*</span><span class="p">,</span>
|
||
<span class="n">qkv</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">past_key_value</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">attention_packed_mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">sequence_length</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">host_past_key_value_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
|
||
<span class="n">host_max_attention_window_sizes</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">host_sink_token_length</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
|
||
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
|
||
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">layer_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">num_kv_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">hidden_size_per_head</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">q_scaling</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
|
||
<span class="n">attn_logit_softcapping_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_base</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_scale_type</span><span class="p">:</span> <span class="n">RotaryScalingType</span> <span class="o">=</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">none</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_short_m_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_long_m_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_max_positions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1024</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_original_max_positions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1024</span><span class="p">,</span>
|
||
<span class="n">position_embedding_type</span><span class="p">:</span> <span class="n">PositionEmbeddingType</span> <span class="o">=</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span>
|
||
<span class="n">learned_absolute</span><span class="p">,</span>
|
||
<span class="n">rotary_inv_freq</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">rotary_cos_sin</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">kv_orig_quant_scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">kv_quant_orig_scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">attention_output_orig_quant_scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">kv_cache_quant_mode</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">QuantModeWrapper</span><span class="p">,</span> <span class="n">QuantMode</span><span class="p">]</span> <span class="o">=</span> <span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">max_context_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">mask_type</span><span class="p">:</span> <span class="n">AttentionMaskType</span> <span class="o">=</span> <span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">causal</span><span class="p">,</span>
|
||
<span class="n">block_sparse_block_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">,</span>
|
||
<span class="n">block_sparse_homo_head_pattern</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">block_sparse_num_local_blocks</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span><span class="p">,</span>
|
||
<span class="n">block_sparse_vertical_stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span>
|
||
<span class="n">alibi_slopes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">tp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">tp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">vision_start</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">vision_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">kv_cache_block_offsets</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">host_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">host_kv_cache_pool_pointers</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">host_kv_cache_pool_mapping</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">do_cross_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">cross_kv</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
|
||
<span class="n">cross_kv_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
|
||
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
|
||
<span class="n">relative_attention_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for relative attention</span>
|
||
<span class="n">logn_scaling</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for logn scaling</span>
|
||
<span class="n">max_distance</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="c1"># for relative attention</span>
|
||
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for pad-free input mode</span>
|
||
<span class="n">qkv_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">use_cache</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_is_generation_length_variable</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_max_generation_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_generation_lengths</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_position_offsets</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_packed_mask</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_use</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">long_rope_rotary_inv_freq</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">long_rope_rotary_cos_sin</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">mrope_rotary_cos_sin</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">mrope_position_deltas</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">host_runtime_perf_knobs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">host_context_progress</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">layer_idx_in_cache_pool</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">is_mla_enabled_flag</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">q_lora_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">kv_lora_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">qk_nope_head_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">qk_rope_head_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">v_head_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">fused_q_proj</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">q_b_proj</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">kv_b_proj</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">skip_attn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">cp_group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">],</span>
|
||
<span class="n">cp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">cp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]]:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs the multi-head attention in GPT-like models.</span>
|
||
|
||
<span class="sd"> The signature of the function will change in the future release - we are in</span>
|
||
<span class="sd"> the process of simplifying the API. The current version is still</span>
|
||
<span class="sd"> work-in-progress! The following API is provided with hints regarding the</span>
|
||
<span class="sd"> arguments that are likely to be removed or merged with others in the future</span>
|
||
<span class="sd"> release.</span>
|
||
|
||
<span class="sd"> See docs/source/advanced/gpt-attention.md for the documentation of that function.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> qkv: Tensor (On GPU)</span>
|
||
<span class="sd"> The input QKV tensor. Its shape is [batch_beam_size, max_seqlen, qkv_dim] in padded mode and [1, num_tokens, qkv_dim] in</span>
|
||
<span class="sd"> packed mode. Where qkv_dim depends on using MQA, GQA, or MHA. See QKV Input in docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> past_key_value: Tensor (On GPU)</span>
|
||
<span class="sd"> The tensor that stores KV cache data. Its shape is</span>
|
||
<span class="sd"> [max_batch_size * max_beam_width, 2, num_kv_heads, max_seqlen, hidden_dim_per_head]</span>
|
||
<span class="sd"> in contiguous mode and</span>
|
||
<span class="sd"> [max_blocks, 2, num_kv_heads, num_tokens_per_block, hidden_dim_per_head]</span>
|
||
<span class="sd"> in paged mode. See KV Cache in docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> attention_mask: Tensor (On GPU)</span>
|
||
<span class="sd"> The tensor that stores the attention mask for unfused MHA or MMHA.</span>
|
||
<span class="sd"> Its shape is [num_tokens, max_kv_seqlen].</span>
|
||
|
||
<span class="sd"> attention_packed_mask: Tensor (On GPU)</span>
|
||
<span class="sd"> The tensor that stores the packed custom mask for fmha.</span>
|
||
<span class="sd"> Its shape is [num_tokens, max_kv_seqlen / 32], where each bit represents one mask position.</span>
|
||
|
||
<span class="sd"> sequence_length: Tensor (On GPU)</span>
|
||
<span class="sd"> The tensor that stores the length of each sequence. Its shape is</span>
|
||
<span class="sd"> [batch_size]. See QKV Input in docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> host_past_key_value_lengths: Tensor (On CPU)</span>
|
||
<span class="sd"> An INT32 tensor of shape [batch_size],</span>
|
||
|
||
<span class="sd"> host_max_attention_window_sizes: Tensor (On CPU)</span>
|
||
<span class="sd"> An INT32 tensor of shape [1].</span>
|
||
<span class="sd"> By default, the max_attention_window_size is determined by the shape of cache_indir_table.</span>
|
||
<span class="sd"> An independent max_attention_window_size is supported for each layer.</span>
|
||
<span class="sd"> This controls the sliding-window-attention/cyclic-kv-cache features.</span>
|
||
|
||
<span class="sd"> context_lengths: Tensor (On GPU)</span>
|
||
<span class="sd"> The tensor that stores the context-phase sequence length of each request. Its shape</span>
|
||
<span class="sd"> is [batch_size]. See QKV Input in docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> cache_indirection: Tensor (On GPU)</span>
|
||
<span class="sd"> The tensor to reconstruct the paths when using beam-search. Its</span>
|
||
<span class="sd"> shape is [batch_size, beam_width, max_seqlen]. See Beam-Search in</span>
|
||
<span class="sd"> docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> host_request_types: Tensor (On CPU)</span>
|
||
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
|
||
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
|
||
<span class="sd"> in docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> layer_idx: int</span>
|
||
<span class="sd"> The index of this attention layer, used to access kv_cache_block_offsets,</span>
|
||
|
||
<span class="sd"> num_heads: int</span>
|
||
<span class="sd"> The number of heads,</span>
|
||
|
||
<span class="sd"> num_kv_heads: int</span>
|
||
<span class="sd"> The number of KV heads, generic to handle MHA/MQA/GQA,</span>
|
||
|
||
<span class="sd"> hidden_size_per_head: int</span>
|
||
<span class="sd"> The hidden size per head,</span>
|
||
|
||
<span class="sd"> q_scaling: float</span>
|
||
<span class="sd"> The value used to compute the scaling factor applied to the output</span>
|
||
<span class="sd"> of the Q*K^T product. See Scaling Factors in docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> attn_logit_softcapping_scale: float</span>
|
||
<span class="sd"> The scale used in scale * tanh(value / scale) to soft-cap the output</span>
|
||
<span class="sd"> of the Q*K^T product.</span>
|
||
|
||
<span class="sd"> rotary_embedding_dim: int</span>
|
||
<span class="sd"> The dimension to compute RoPE. Use 0 when position_embedding_type is not RoPE.</span>
|
||
|
||
<span class="sd"> rotary_embedding_base: float</span>
|
||
<span class="sd"> The theta value to use for RoPE. Ignored when position_embedding_type is not RoPE.</span>
|
||
|
||
<span class="sd"> rotary_embedding_scale_type: RotaryScalingType</span>
|
||
<span class="sd"> The scaling type of RoPE. Ignored when position_embedding_type is not RoPE.</span>
|
||
<span class="sd"> Possible rotary scaling type:</span>
|
||
<span class="sd"> * RotaryScalingType.none</span>
|
||
<span class="sd"> * RotaryScalingType.linear</span>
|
||
<span class="sd"> * RotaryScalingType.dynamic</span>
|
||
<span class="sd"> * RotaryScalingType.longrope</span>
|
||
<span class="sd"> * RotaryScalingType.llama3</span>
|
||
|
||
<span class="sd"> rotary_embedding_scale: float</span>
|
||
<span class="sd"> The scale value to use for linear/dynamic scaling in RoPE.</span>
|
||
<span class="sd"> Ignored when position_embedding_type is not RoPE.</span>
|
||
<span class="sd"> Must be set to 1 (default) if rotary_embedding_scale_type is `none`.</span>
|
||
|
||
<span class="sd"> rotary_inv_freq: float Tensor</span>
|
||
<span class="sd"> The rotary inv freq with shape [head_size / 2].</span>
|
||
|
||
<span class="sd"> rotary_cos_sin: float2(cos/sin) Tensor</span>
|
||
<span class="sd"> The rotary cos/sin cache, which will be reused among different requests.</span>
|
||
<span class="sd"> It is taken as constant tensor.</span>
|
||
|
||
<span class="sd"> rotary_embedding_max_positions: int</span>
|
||
<span class="sd"> Needed only for `dynamic` RoPE scaling. Ignored otherwise.</span>
|
||
|
||
<span class="sd"> position_embedding_type: PositionEmbeddingType</span>
|
||
<span class="sd"> The position embedding type:</span>
|
||
<span class="sd"> * PositionEmbeddingType.learned_absolute</span>
|
||
<span class="sd"> * PositionEmbeddingType.relative</span>
|
||
<span class="sd"> * PositionEmbeddingType.rope_gptj</span>
|
||
<span class="sd"> * PositionEmbeddingType.rope_gpt_neox</span>
|
||
<span class="sd"> * PositionEmbeddingType.alibi</span>
|
||
<span class="sd"> * PositionEmbeddingType.alibi_with_scale</span>
|
||
|
||
<span class="sd"> kv_orig_quant_scale: Tensor</span>
|
||
<span class="sd"> The tensor to store the scaling factor for quantization to INT8/FP8</span>
|
||
<span class="sd"> in the KV cache. Its shape is [1]. See INT8/FP8 KV Cache in</span>
|
||
<span class="sd"> docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> kv_quant_orig_scale: Tensor</span>
|
||
<span class="sd"> The tensor to store the scaling factor for dequantization from</span>
|
||
<span class="sd"> INT8/FP8 in the KV cache. Its shape is [1]. See INT8/FP8 KV Cache</span>
|
||
<span class="sd"> in docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> attention_output_orig_quant_scale: Tensor</span>
|
||
<span class="sd"> The tensor to store the scaling factor for quantization to FP8</span>
|
||
<span class="sd"> of the attention output. Its shape is [1].</span>
|
||
|
||
<span class="sd"> kv_cache_quant_mode: QuantMode (int flags)</span>
|
||
<span class="sd"> Do we enable the INT8 or FP8 KV cache?</span>
|
||
|
||
<span class="sd"> max_context_length: int32_t</span>
|
||
<span class="sd"> The length of the longest input sequence. See QKV Input in</span>
|
||
<span class="sd"> docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> mask_type: AttentionMaskType = AttentionMaskType.causal</span>
|
||
<span class="sd"> The type of mask:</span>
|
||
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.padding for BERT,</span>
|
||
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.causal for GPT,</span>
|
||
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.sliding_window_causal for GPT,</span>
|
||
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.bidirectional for ChatGLM-6B,</span>
|
||
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.bidirectionalglm for GLM-10B,</span>
|
||
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.blocksparse for Phi-3-small,</span>
|
||
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.custom_mask for any models.</span>
|
||
|
||
<span class="sd"> block_sparse_block_size: int</span>
|
||
<span class="sd"> Block size in block sparse attention</span>
|
||
|
||
<span class="sd"> block_sparse_homo_head_pattern: bool</span>
|
||
<span class="sd"> Do all attention heads share same vertical stride pattern?</span>
|
||
|
||
<span class="sd"> block_sparse_num_local_blocks: int</span>
|
||
<span class="sd"> Number of active blocks near diagonal</span>
|
||
|
||
<span class="sd"> block_sparse_vertical_stride: int</span>
|
||
<span class="sd"> Stride of active blocks in vertical dimension</span>
|
||
|
||
<span class="sd"> alibi_slopes: Tensor</span>
|
||
<span class="sd"> The ALiBi slopes. The ALiBi bias is computed on-the-fly in the kernel</span>
|
||
<span class="sd"> when possible,</span>
|
||
|
||
<span class="sd"> tp_size: int</span>
|
||
<span class="sd"> The number of processes/GPUs when tensor parallelism is activated,</span>
|
||
|
||
<span class="sd"> tp_rank: int</span>
|
||
<span class="sd"> The rank of that process (when running tensor parallelism),</span>
|
||
|
||
<span class="sd"> kv_cache_block_offsets:</span>
|
||
<span class="sd"> The tensor of block offsets for the KV cache. Its shape is</span>
|
||
<span class="sd"> [num_layers, max_batch_size, max_beam_width, 2, max_blocks_per_sequence * 2],</span>
|
||
<span class="sd"> See KV cache section in docs/source/advanced/gpt-attention.md, on gpu,</span>
|
||
|
||
<span class="sd"> host_kv_cache_block_offsets:</span>
|
||
<span class="sd"> The same as kv_cache_block_offsets, but on cpu,</span>
|
||
|
||
<span class="sd"> host_kv_cache_pool_pointers:</span>
|
||
<span class="sd"> The tensor of pool pointers for the KV cache. Its shape is [num_layers, 2],</span>
|
||
<span class="sd"> See KV cache section in docs/source/advanced/gpt-attention.md, on cpu,</span>
|
||
|
||
<span class="sd"> host_kv_cache_pool_mapping:</span>
|
||
<span class="sd"> The tensor of pool mapping for the different memory pools. Its shape is [num_layers,],</span>
|
||
|
||
<span class="sd"> do_cross_attention: bool = False</span>
|
||
<span class="sd"> Do we use this as cross attention instead of self attention,</span>
|
||
|
||
<span class="sd"> cross_kv: Tensor = None</span>
|
||
<span class="sd"> The KV tensor of encoder output hidden states. Its shape is [batch_size, max_seqlen, 2 * kvHeadNum * headSize] in padded mode and [1, num_tokens, 2 * kvHeadNum * headSize] in</span>
|
||
<span class="sd"> packed mode,</span>
|
||
|
||
<span class="sd"> cross_kv_length: Tensor = None</span>
|
||
<span class="sd"> The length of the longest encoder output sequence,</span>
|
||
|
||
<span class="sd"> encoder_input_lengths: Tensor</span>
|
||
<span class="sd"> The tensor that stores the length of each encoder input sequence. Its shape is [batch_size],</span>
|
||
|
||
<span class="sd"> logn_scaling: Tensor = None</span>
|
||
<span class="sd"> The logn scaling tensor [max_position_embedding_len], which is applied to q in order to help extrapolation</span>
|
||
|
||
<span class="sd"> relative_attention_bias: Tensor = None</span>
|
||
<span class="sd"> The relative attention bias [num_heads, max_seq_len, max_seq_len], or the relative attention embedding table for implicit mode, [num_heads, num_buckets].</span>
|
||
|
||
<span class="sd"> max_distance: int = 0</span>
|
||
<span class="sd"> The maximum distance of relative position in attention, for implicit mode.</span>
|
||
<span class="sd"> Default value is 0, meaning to use the regular mode of relative attention bias.</span>
|
||
<span class="sd"> Implicit mode is only enabled when passing in non-zero positive max_distance value.</span>
|
||
<span class="sd"> See relative attention bias in docs/source/advanced/gpt-attention.md</span>
|
||
|
||
<span class="sd"> host_context_lengths: Tensor = None (On CPU)</span>
|
||
<span class="sd"> A host tensor that contains the lengths of the different inputs,</span>
|
||
|
||
<span class="sd"> qkv_bias: Tensor = None,</span>
|
||
<span class="sd"> The qkv bias tensor.</span>
|
||
|
||
<span class="sd"> use_cache: bool = False</span>
|
||
<span class="sd"> Whether to store the KV cache; not needed if there is no generation phase.</span>
|
||
|
||
<span class="sd"> spec_decoding_is_generation_length_variable: bool = False,</span>
|
||
<span class="sd"> Whether the generation lengths can be different for each sequence in a batch.</span>
|
||
<span class="sd"> For Medusa, this should be set to False.</span>
|
||
<span class="sd"> For Redrafter, this should be set to True.</span>
|
||
|
||
<span class="sd"> spec_decoding_max_generation_length: int = 1,</span>
|
||
<span class="sd"> The maximum number of tokens possible in the generation phase per sequence.</span>
|
||
|
||
<span class="sd"> spec_decoding_generation_lengths: Tensor = None,</span>
|
||
<span class="sd"> The generation phase tokens' lengths for each sequence.</span>
|
||
<span class="sd"> Shape: [batch_size]</span>
|
||
|
||
<span class="sd"> spec_decoding_position_offsets: Tensor = None,</span>
|
||
<span class="sd"> The speculative decoding tokens' position offsets (shared by all sequences).</span>
|
||
<span class="sd"> Shape: [batch_size, num_draft_tokens + 1].</span>
|
||
|
||
<span class="sd"> spec_decoding_packed_mask: Tensor = None,</span>
|
||
<span class="sd"> The speculative decoding tokens' attention mask (packed into uint32_t bits).</span>
|
||
<span class="sd"> remove_input_padding is False:</span>
|
||
<span class="sd"> Shape: [batch_size, num_draft_tokens + 1, divUp(num_draft_tokens + 1, 32)].</span>
|
||
<span class="sd"> remove_input_padding is True:</span>
|
||
<span class="sd"> Shape: [sum(spec_decoding_generation_lengths), divUp(num_draft_tokens + 1, 32)].</span>
|
||
|
||
<span class="sd"> long_rope_rotary_inv_freq: float Tensor</span>
|
||
<span class="sd"> Additional rotary inv freq used for longer sequence lengths. Shape: [head_size / 2]</span>
|
||
|
||
<span class="sd"> long_rope_rotary_cos_sin: float2(cos/sin) Tensor</span>
|
||
<span class="sd"> Additional rotary cos/sin cache used for longer sequence lengths.</span>
|
||
|
||
<span class="sd"> is_mla_enable: bool = False</span>
|
||
<span class="sd"> Whether to enable DeepSeek-V2 MLA (multi-head latent attention).</span>
|
||
|
||
<span class="sd"> host_runtime_perf_knobs: Tensor = None,</span>
|
||
<span class="sd"> The runtime perf knobs bit mask, controls whether to use certain perf knob in the runtime.</span>
|
||
|
||
<span class="sd"> host_context_progress: Tensor = None,</span>
|
||
<span class="sd"> The structure used to track layer-wise progress in context phase.</span>
|
||
|
||
<span class="sd"> skip_attn: Tensor = None,</span>
|
||
<span class="sd"> A bool tensor on CPU. If it is true, don't run attention plugin, returning directly.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="k">assert</span> <span class="n">host_request_types</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="p">(</span><span class="n">alibi_slopes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">==</span> <span class="p">(</span><span class="n">position_embedding_type</span><span class="o">.</span><span class="n">is_alibi</span><span class="p">())</span>
|
||
<span class="k">assert</span> <span class="p">(</span><span class="n">mrope_rotary_cos_sin</span>
|
||
<span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">==</span> <span class="p">(</span><span class="n">position_embedding_type</span><span class="o">.</span><span class="n">is_mrope</span><span class="p">())</span>
|
||
<span class="n">attn_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'GPTAttention'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">attn_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">host_context_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">(</span>
|
||
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">host_max_attention_window_sizes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">host_sink_token_length</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="k">if</span> <span class="n">layer_idx_in_cache_pool</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">layer_idx_in_cache_pool</span> <span class="o">=</span> <span class="n">layer_idx</span>
|
||
|
||
<span class="n">paged_kv_cache_flag</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_kv_cache</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
|
||
<span class="n">is_unfuse_qkv_gemm</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">is_unfuse_qkv_gemm</span> <span class="o">=</span> <span class="mi">0</span>
|
||
|
||
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha_type</span>
|
||
<span class="k">if</span> <span class="n">do_cross_attention</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
|
||
<span class="k">pass</span>
|
||
<span class="k">if</span> <span class="n">logn_scaling</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">use_logn_scaling</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">use_logn_scaling</span> <span class="o">=</span> <span class="mi">0</span>
|
||
|
||
<span class="n">unfuse_qkv_gemm</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"unfuse_qkv_gemm"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">is_unfuse_qkv_gemm</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"layer_idx"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">layer_idx</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">nheads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"num_heads"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">vision_start</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"vision_start"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">vision_start</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">vision_length</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"vision_length"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">vision_length</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">num_kv_heads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"num_kv_heads"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_kv_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">layer_idx_in_cache_pool</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"layer_idx_in_cache_pool"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">layer_idx_in_cache_pool</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">head_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"head_size"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">hidden_size_per_head</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">unidirectional</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"unidirectional"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">q_scaling</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"q_scaling"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">q_scaling</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">attn_logit_softcapping_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"attn_logit_softcapping_scale"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">attn_logit_softcapping_scale</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_dim"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_base</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_base"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_base</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_scale_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_scale_type"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_scale_type</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_scale"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_scale</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_short_m_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_short_m_scale"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_short_m_scale</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_long_m_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_long_m_scale"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_long_m_scale</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_max_positions</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_max_positions"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_max_positions</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_original_max_positions</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_original_max_positions"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_original_max_positions</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">position_embedding_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"position_embedding_type"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">position_embedding_type</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">context_fmha_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"context_fmha_type"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha_type</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"remove_input_padding"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">is_spec_decoding_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"is_spec_decoding_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">spec_decoding_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">spec_decoding_is_generation_length_variable</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"spec_decoding_is_generation_length_variable"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">spec_decoding_is_generation_length_variable</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">spec_decoding_max_generation_length</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"spec_decoding_max_generation_length"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">spec_decoding_max_generation_length</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">is_mla_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"is_mla_enabled"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">is_mla_enabled_flag</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">q_lora_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"q_lora_rank"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">q_lora_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">kv_lora_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"kv_lora_rank"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">kv_lora_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">qk_nope_head_dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"qk_nope_head_dim"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">qk_nope_head_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">qk_rope_head_dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"qk_rope_head_dim"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">qk_rope_head_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">v_head_dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"v_head_dim"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">v_head_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="c1"># reset mask_type to custom_mask.</span>
|
||
<span class="k">if</span> <span class="p">(</span><span class="n">attention_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">attention_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">):</span>
|
||
<span class="c1"># context fmha needs packed mask.</span>
|
||
<span class="k">assert</span> <span class="n">attention_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">mask_type</span> <span class="o">=</span> <span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">custom_mask</span>
|
||
|
||
<span class="n">mask_type_filed</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"mask_type"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">mask_type</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">block_sparse_block_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"block_sparse_block_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">block_sparse_block_size</span><span class="p">],</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">block_sparse_homo_head_pattern</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"block_sparse_homo_head_pattern"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">block_sparse_homo_head_pattern</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">block_sparse_num_local_blocks</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"block_sparse_num_local_blocks"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">block_sparse_num_local_blocks</span><span class="p">],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">block_sparse_vertical_stride</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"block_sparse_vertical_stride"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">block_sparse_vertical_stride</span><span class="p">],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">tp_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"tp_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tp_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">tp_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"tp_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tp_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">kv_cache_quant_mode</span><span class="p">,</span> <span class="n">QuantModeWrapper</span><span class="p">):</span>
|
||
<span class="c1"># Now in TRT-LLM only use global kv_cache, so it's enough to get the first quant mode from list</span>
|
||
<span class="n">kv_cache_quant_mode</span> <span class="o">=</span> <span class="n">kv_cache_quant_mode</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="n">kv_cache_quant_mode_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"kv_cache_quant_mode"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">kv_cache_quant_mode</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">paged_kv_cache</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"paged_kv_cache"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">paged_kv_cache_flag</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">tokens_per_block</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"tokens_per_block"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"max_context_length"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pos_shift_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"pos_shift_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">streamingllm</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">dense_context_fmha</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"dense_context_fmha"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">streamingllm</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">qkv_bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">qkv_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"qkv_bias_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">qkv_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"qkv_bias_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">do_cross_attention_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"do_cross_attention"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">do_cross_attention</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">max_distance</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"max_distance"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_distance</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">use_paged_context_fmha_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"use_paged_context_fmha"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">use_fp8_context_fmha_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"use_fp8_context_fmha"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">has_full_attention_mask_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"has_full_attention_mask"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">attention_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">use_cache_pf</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"use_cache"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">use_cache</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">skip_attn_pf</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"skip_attn"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">skip_attn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">cp_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"cp_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">cp_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"cp_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">cp_group</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
||
<span class="n">cp_group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"cp_group"</span><span class="p">,</span> <span class="n">cp_group</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">use_logn_scaling</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"use_logn_scaling"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">use_logn_scaling</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
|
||
<span class="n">layer_idx</span><span class="p">,</span> <span class="n">nheads</span><span class="p">,</span> <span class="n">vision_start</span><span class="p">,</span> <span class="n">vision_length</span><span class="p">,</span> <span class="n">num_kv_heads</span><span class="p">,</span>
|
||
<span class="n">layer_idx_in_cache_pool</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span> <span class="n">unidirectional</span><span class="p">,</span> <span class="n">q_scaling</span><span class="p">,</span>
|
||
<span class="n">attn_logit_softcapping_scale</span><span class="p">,</span> <span class="n">position_embedding_type</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_dim</span><span class="p">,</span> <span class="n">rotary_embedding_base</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_scale_type</span><span class="p">,</span> <span class="n">rotary_embedding_scale</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_short_m_scale</span><span class="p">,</span> <span class="n">rotary_embedding_long_m_scale</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_max_positions</span><span class="p">,</span> <span class="n">rotary_embedding_original_max_positions</span><span class="p">,</span>
|
||
<span class="n">tp_size</span><span class="p">,</span> <span class="n">tp_rank</span><span class="p">,</span> <span class="n">unfuse_qkv_gemm</span><span class="p">,</span> <span class="n">context_fmha_type</span><span class="p">,</span>
|
||
<span class="n">kv_cache_quant_mode_field</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">mask_type_filed</span><span class="p">,</span>
|
||
<span class="n">block_sparse_block_size</span><span class="p">,</span> <span class="n">block_sparse_homo_head_pattern</span><span class="p">,</span>
|
||
<span class="n">block_sparse_num_local_blocks</span><span class="p">,</span> <span class="n">block_sparse_vertical_stride</span><span class="p">,</span>
|
||
<span class="n">paged_kv_cache</span><span class="p">,</span> <span class="n">tokens_per_block</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
|
||
<span class="n">qkv_bias_enabled</span><span class="p">,</span> <span class="n">do_cross_attention_field</span><span class="p">,</span> <span class="n">max_distance</span><span class="p">,</span>
|
||
<span class="n">pos_shift_enabled</span><span class="p">,</span> <span class="n">dense_context_fmha</span><span class="p">,</span> <span class="n">use_paged_context_fmha_field</span><span class="p">,</span>
|
||
<span class="n">use_fp8_context_fmha_field</span><span class="p">,</span> <span class="n">has_full_attention_mask_field</span><span class="p">,</span> <span class="n">use_cache_pf</span><span class="p">,</span>
|
||
<span class="n">is_spec_decoding_enabled</span><span class="p">,</span> <span class="n">spec_decoding_is_generation_length_variable</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_max_generation_length</span><span class="p">,</span> <span class="n">is_mla_enabled</span><span class="p">,</span> <span class="n">q_lora_rank</span><span class="p">,</span>
|
||
<span class="n">kv_lora_rank</span><span class="p">,</span> <span class="n">qk_nope_head_dim</span><span class="p">,</span> <span class="n">qk_rope_head_dim</span><span class="p">,</span> <span class="n">v_head_dim</span><span class="p">,</span>
|
||
<span class="n">skip_attn_pf</span><span class="p">,</span> <span class="n">cp_size</span><span class="p">,</span> <span class="n">cp_rank</span><span class="p">,</span> <span class="n">cp_group</span><span class="p">,</span> <span class="n">use_logn_scaling</span>
|
||
<span class="p">])</span>
|
||
|
||
<span class="n">attn_plug</span> <span class="o">=</span> <span class="n">attn_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"causal_attn"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">attn_plug</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="o">*</span><span class="n">qkv</span><span class="p">]</span> <span class="k">if</span> <span class="n">is_unfuse_qkv_gemm</span> <span class="k">else</span> <span class="p">[</span><span class="n">qkv</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">attention_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">mask_type</span> <span class="o">==</span> <span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">custom_mask</span><span class="p">:</span>
|
||
<span class="c1"># useFullCustomMask</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">attention_mask</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">attention_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="c1"># usePackedCustomMask</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">attention_packed_mask</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="n">sequence_length</span><span class="p">,</span>
|
||
<span class="n">host_past_key_value_lengths</span><span class="p">,</span>
|
||
<span class="n">host_max_attention_window_sizes</span><span class="p">,</span>
|
||
<span class="n">host_sink_token_length</span><span class="p">,</span>
|
||
<span class="n">context_lengths</span><span class="p">,</span>
|
||
<span class="n">cache_indirection</span><span class="p">,</span>
|
||
<span class="n">host_request_types</span><span class="p">,</span>
|
||
<span class="p">]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="n">host_max_attention_window_sizes</span><span class="p">,</span>
|
||
<span class="n">host_sink_token_length</span><span class="p">,</span>
|
||
<span class="n">context_lengths</span><span class="p">,</span>
|
||
<span class="n">host_request_types</span><span class="p">,</span>
|
||
<span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">kv_cache_block_offsets</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"Paged kv cache is enabled, the kv_cache_block_offsets tensor shall not be None"</span>
|
||
<span class="k">assert</span> <span class="n">host_kv_cache_block_offsets</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"Paged kv cache is enabled, the host_kv_cache_block_offsets tensor shall not be None"</span>
|
||
<span class="k">assert</span> <span class="n">host_kv_cache_pool_pointers</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"Paged kv cache is enabled, the host_kv_cache_pool_pointers tensor shall not be None"</span>
|
||
<span class="k">assert</span> <span class="n">host_kv_cache_pool_mapping</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"Paged kv cache is enabled, the host_kv_cache_pool_mapping tensor shall not be None"</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="n">kv_cache_block_offsets</span><span class="p">,</span> <span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
|
||
<span class="n">host_kv_cache_pool_pointers</span><span class="p">,</span> <span class="n">host_kv_cache_pool_mapping</span>
|
||
<span class="p">]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">past_key_value</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">use_cache</span> <span class="ow">and</span> <span class="n">kv_cache_quant_mode</span><span class="o">.</span><span class="n">has_kv_cache_quant</span><span class="p">():</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">kv_orig_quant_scale</span><span class="p">,</span> <span class="n">kv_quant_orig_scale</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">attention_output_orig_quant_scale</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">default_net</span><span class="p">(</span>
|
||
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span><span class="p">,</span> <span class="s2">"FP8 Context FMHA needs to be enabled"</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">attention_output_orig_quant_scale</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">rotary_inv_freq</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">rotary_inv_freq</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">rotary_cos_sin</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">rotary_cos_sin</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">alibi_slopes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">alibi_slopes</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">relative_attention_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">relative_attention_bias</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">do_cross_attention</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">cross_kv</span><span class="p">,</span> <span class="n">cross_kv_length</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">qkv_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">qkv_bias</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">spec_decoding_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="c1"># add position_ids as well only if speculative decoding mode</span>
|
||
<span class="k">assert</span> <span class="n">spec_decoding_position_offsets</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">spec_decoding_generation_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">spec_decoding_use</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="n">spec_decoding_generation_lengths</span><span class="p">,</span> <span class="n">spec_decoding_packed_mask</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_position_offsets</span><span class="p">,</span> <span class="n">spec_decoding_use</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">long_rope_rotary_inv_freq</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">long_rope_rotary_cos_sin</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">long_rope_rotary_inv_freq</span><span class="p">,</span> <span class="n">long_rope_rotary_cos_sin</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">mrope_rotary_cos_sin</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">mrope_position_deltas</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="n">mrope_rotary_cos_sin</span><span class="p">,</span>
|
||
<span class="n">mrope_position_deltas</span><span class="p">,</span>
|
||
<span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">host_runtime_perf_knobs</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_runtime_perf_knobs</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">host_context_progress</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_progress</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">is_mla_enabled_flag</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">fused_q_proj</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">q_b_proj</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">kv_b_proj</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">fused_q_proj</span><span class="p">,</span> <span class="n">q_b_proj</span><span class="p">,</span> <span class="n">kv_b_proj</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">skip_attn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">skip_attn</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">logn_scaling</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">logn_scaling</span><span class="p">]</span>
|
||
|
||
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="n">i</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"Found None input for </span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s2"> th item in plugin inputs </span><span class="si">{</span><span class="n">plug_inputs</span><span class="si">}</span><span class="s2">"</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">attn_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">attn_plg_creator</span><span class="p">,</span> <span class="s2">"causal_attn"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="n">present_key_value</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="k">if</span> <span class="n">use_cache</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
|
||
<span class="n">present_key_value</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">present_key_value</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">expected_outputs</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">expected_outputs</span> <span class="o">=</span> <span class="mi">1</span>
|
||
|
||
<span class="k">assert</span> <span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span> <span class="o">==</span> <span class="n">expected_outputs</span><span class="p">,</span> \
|
||
<span class="sa">f</span><span class="s2">"Plugin outputs number mismatch with expected, got </span><span class="si">{</span><span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span><span class="si">}</span><span class="s2">, expected </span><span class="si">{</span><span class="n">expected_outputs</span><span class="si">}</span><span class="s2">"</span>
|
||
|
||
<span class="k">if</span> <span class="n">kv_cache_quant_mode</span><span class="o">.</span><span class="n">has_int8_kv_cache</span><span class="p">(</span>
|
||
<span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
|
||
<span class="c1"># past key value</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">8</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
<span class="c1"># present key value</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
|
||
<span class="k">assert</span> <span class="n">output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">present_key_value</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="assertion">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.assertion">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">assertion</span><span class="p">(</span><span class="n">condition</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">message</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">''</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a runtime assertion on a tensor to the current network.</span>
|
||
|
||
<span class="sd"> The underlying TensorRT tensor of 'condition' is handed to</span>
|
||
<span class="sd"> default_trtnet().add_assertion together with 'message'. No output</span>
|
||
<span class="sd"> tensor is created and the function returns None.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> condition : Tensor</span>
|
||
<span class="sd"> The tensor to check. NOTE(review): TensorRT assertion layers</span>
|
||
<span class="sd"> expect a boolean tensor; confirm callers always pass one.</span>
|
||
|
||
<span class="sd"> message : str = ''</span>
|
||
<span class="sd"> The message attached to the assertion layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_assertion</span><span class="p">(</span><span class="n">condition</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">message</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="layer_norm">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.layer_norm">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">layer_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">,</span>
|
||
<span class="n">use_diff_of_squares</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a layer-norm operation on a tensor.</span>
|
||
|
||
<span class="sd"> That operation applies the layer-normalization to its input tensor. In its</span>
|
||
<span class="sd"> simplest form, for large language models, the 'normalized_shape' should be</span>
|
||
<span class="sd"> set to the hidden dimension of the activation tensor. Otherwise, it is the</span>
|
||
<span class="sd"> shape of the normalized fraction of the tensor (starting from the</span>
|
||
<span class="sd"> right-most dimension).</span>
|
||
|
||
<span class="sd"> The 'weight' tensor corresponds to 'gamma' in the layer-norm formula and</span>
|
||
<span class="sd"> 'bias' is 'beta'. The 'eps' value is added to the variance before computing</span>
|
||
<span class="sd"> the square root.</span>
|
||
|
||
<span class="sd"> This implementation builds a single TensorRT normalization layer</span>
|
||
<span class="sd"> (add_normalization) over the trailing dimensions of 'input'. Only the</span>
|
||
<span class="sd"> length of 'normalized_shape' selects those dimensions; its values are not</span>
|
||
<span class="sd"> read. The 'use_diff_of_squares' flag is accepted for API compatibility but</span>
|
||
<span class="sd"> is not used by this implementation.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The tensor to normalize.</span>
|
||
|
||
<span class="sd"> normalized_shape : Union[int, Tuple[int]]</span>
|
||
<span class="sd"> The shape of the sub-tensor that is normalized. Use 'hidden_dim' to</span>
|
||
<span class="sd"> normalize the inner-most dimension of an activation tensor in LLMs.</span>
|
||
<span class="sd"> An int normalizes the last dimension only; a tuple normalizes the</span>
|
||
<span class="sd"> last len(normalized_shape) dimensions.</span>
|
||
|
||
<span class="sd"> weight : Optional[Tensor] = None</span>
|
||
<span class="sd"> The 'gamma' term in layer-norm. Its shape must be</span>
|
||
<span class="sd"> 'normalized_shape'. NOTE(review): the body dereferences</span>
|
||
<span class="sd"> weight.trt_tensor after broadcast_helper; confirm that helper turns</span>
|
||
<span class="sd"> None into a tensor before relying on the default.</span>
|
||
|
||
<span class="sd"> bias : Optional[Tensor] = None</span>
|
||
<span class="sd"> The 'beta' term in layer-norm. Its shape must be</span>
|
||
<span class="sd"> 'normalized_shape'. The same NOTE(review) as for 'weight' applies.</span>
|
||
|
||
<span class="sd"> eps : float</span>
|
||
<span class="sd"> The epsilon term added to the variance before the square root; it</span>
|
||
<span class="sd"> is stored as the normalization layer's epsilon.</span>
|
||
|
||
<span class="sd"> use_diff_of_squares : bool</span>
|
||
<span class="sd"> Unused here; kept so existing callers that pass it keep working.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor of that operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="nb">input</span><span class="p">,</span> <span class="n">weight</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">weight</span><span class="p">)</span>
|
||
<span class="nb">input</span><span class="p">,</span> <span class="n">bias</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
<span class="c1"># 'axis' is the first normalized dimension; only len(normalized_shape) matters.</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span> <span class="c1"># FIXME: better way?</span>
|
||
<span class="n">axis</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">axis</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">)</span>
|
||
<span class="c1"># Bit i of the mask selects dimension i for normalization (axis .. ndim-1).</span>
|
||
<span class="n">axes_mask</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()):</span>
|
||
<span class="n">axes_mask</span> <span class="o">|=</span> <span class="mi">1</span> <span class="o"><<</span> <span class="n">i</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_normalization</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">axes_mask</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">epsilon</span> <span class="o">=</span> <span class="n">eps</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="rms_norm">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.rms_norm">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">rms_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">num_groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-06</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an RMS-norm operation on a tensor.</span>
|
||
|
||
<span class="sd"> That operation applies the RMS-normalization to its input tensor. In its</span>
|
||
<span class="sd"> simplest form, for large language models, the 'normalized_shape' should be</span>
|
||
<span class="sd"> set to the hidden dimension of the activation tensor. Otherwise, it is the</span>
|
||
<span class="sd"> shape of the normalized fraction of the tensor (starting from the</span>
|
||
<span class="sd"> right-most dimension).</span>
|
||
|
||
<span class="sd"> The result is input / sqrt(mean(input^2) + eps). The mean is taken over</span>
|
||
<span class="sd"> the last len(normalized_shape) dimensions with no mean subtraction, and is</span>
|
||
<span class="sd"> computed in float32 before the result is cast back to the input dtype.</span>
|
||
<span class="sd"> The 'weight' tensor corresponds to 'gamma' in the RMS-norm formula and,</span>
|
||
<span class="sd"> when given, multiplies the normalized result.</span>
|
||
|
||
<span class="sd"> When 'num_groups' > 1, the last dimension is viewed as</span>
|
||
<span class="sd"> [num_groups, num_channels // num_groups]. Each group is normalized on its</span>
|
||
<span class="sd"> own, and the result is viewed back to the original shape before 'weight'</span>
|
||
<span class="sd"> is applied.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The tensor to normalize.</span>
|
||
|
||
<span class="sd"> normalized_shape : Union[int, Tuple[int]]</span>
|
||
<span class="sd"> The shape of the sub-tensor that is normalized. Use 'hidden_dim' to</span>
|
||
<span class="sd"> normalize the inner-most dimension of an activation tensor in LLMs.</span>
|
||
<span class="sd"> Only its length is used to select the trailing dimensions.</span>
|
||
|
||
<span class="sd"> num_groups : int = 1</span>
|
||
<span class="sd"> The number of groups the last dimension is split into (not the size</span>
|
||
<span class="sd"> of each group). Values > 1 require 'normalized_shape' to have exactly</span>
|
||
<span class="sd"> one dimension. NOTE(review): divisibility of the last dimension by</span>
|
||
<span class="sd"> num_groups is not checked here.</span>
|
||
|
||
<span class="sd"> weight : Optional[Tensor] = None</span>
|
||
<span class="sd"> The 'gamma' term in RMS-norm. Its shape must be</span>
|
||
<span class="sd"> 'normalized_shape'.</span>
|
||
|
||
<span class="sd"> eps : float</span>
|
||
<span class="sd"> The epsilon term added to the mean of squares before the square root.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor of that operation.</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="n">normalized_shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">normalized_shape</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
|
||
<span class="n">normalized_shape</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="n">normalized_shape</span>
|
||
|
||
<span class="c1"># Negative axes (-1, -2, ...) covering the trailing normalized dimensions.</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="o">-</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">))])</span>
|
||
|
||
<span class="k">if</span> <span class="n">num_groups</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span>
|
||
<span class="n">num_channels</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="n">old_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
|
||
<span class="c1"># Split the last dimension into [num_groups, num_channels // num_groups].</span>
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">+</span>
|
||
<span class="p">[</span><span class="n">num_groups</span><span class="p">,</span> <span class="n">num_channels</span> <span class="o">//</span> <span class="n">num_groups</span><span class="p">])</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
|
||
|
||
<span class="k">with</span> <span class="n">precision</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">):</span>
|
||
<span class="n">input_dtype</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span>
|
||
<span class="n">fp32_input</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">)</span>
|
||
<span class="n">varx</span> <span class="o">=</span> <span class="nb">pow</span><span class="p">(</span><span class="n">fp32_input</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)</span>
|
||
|
||
<span class="c1"># 'varx' is the mean of squares (no mean subtraction), not a true variance.</span>
|
||
<span class="n">varx</span> <span class="o">=</span> <span class="n">varx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="n">denom</span> <span class="o">=</span> <span class="n">varx</span> <span class="o">+</span> <span class="n">eps</span>
|
||
<span class="n">denom</span> <span class="o">=</span> <span class="n">denom</span><span class="o">.</span><span class="n">sqrt</span><span class="p">()</span>
|
||
<span class="n">fp32_y</span> <span class="o">=</span> <span class="n">fp32_input</span> <span class="o">/</span> <span class="n">denom</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">fp32_y</span><span class="p">,</span> <span class="n">input_dtype</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">num_groups</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">old_shape</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">*</span> <span class="n">weight</span>
|
||
|
||
<span class="k">return</span> <span class="n">y</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="repeat_interleave">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.repeat_interleave">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">repeat_interleave</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">repeats</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Repeats elements of a tensor along an axis.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> repeats : int</span>
|
||
<span class="sd"> The number of repetitions along axis specified.</span>
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension along which repetitions are performed.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor with the same shape as input except for repeated elements along specified dim.</span>
|
||
|
||
<span class="sd"> TODO: Allow repeats to be a list of integers and dim to be unspecified.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">expanded_tensor</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">tile_output_size</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">repeats</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">expanded_tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">expanded_tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
<span class="p">])</span>
|
||
<span class="n">tile</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">expanded_tensor</span><span class="p">,</span> <span class="n">tile_output_size</span><span class="p">)</span>
|
||
<span class="n">tile_reshape_size</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())]</span>
|
||
<span class="n">tile_reshape_size</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">tile_reshape_size</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">*</span> <span class="n">repeats</span>
|
||
<span class="n">tensor</span> <span class="o">=</span> <span class="n">tile</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">tile_reshape_size</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="n">tensor</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="generate_logn_scaling">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.generate_logn_scaling">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">generate_logn_scaling</span><span class="p">(</span><span class="n">seq_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8192</span><span class="p">,</span>
|
||
<span class="n">max_position_embeddings</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32768</span><span class="p">)</span> <span class="o">-></span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Compute the Log-N scaling vector for Qwen inference extrapolation</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> seq_length : int</span>
|
||
<span class="sd"> The max seq length in training (default to 8192 in Qwen-1)</span>
|
||
<span class="sd"> max_position_embeddings : int</span>
|
||
<span class="sd"> The max position embeddings. (default to 32768 in Qwen-1)</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A constant np.ndarray that contains logn scaling vector</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">logn_list</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">seq_length</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="o">></span> <span class="n">seq_length</span> <span class="k">else</span> <span class="mi">1</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">max_position_embeddings</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">logn_list</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="generate_alibi_slopes">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.generate_alibi_slopes">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">generate_alibi_slopes</span><span class="p">(</span><span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">tp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">tp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">alibi_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">alibi_bias_max</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">)</span> <span class="o">-></span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Compute the ALiBi slopes as described in https://arxiv.org/abs/2211.05100.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> num_heads : int</span>
|
||
<span class="sd"> The number of heads.</span>
|
||
<span class="sd"> dtype : trt.DataType</span>
|
||
<span class="sd"> The data type of the returned slopes</span>
|
||
<span class="sd"> tp_size : int</span>
|
||
<span class="sd"> The tensor parallelism size</span>
|
||
<span class="sd"> tp_rank : int</span>
|
||
<span class="sd"> The tensor parallelism rank</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A constant tensor that contains the ALiBi slopes.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">start_head_id</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">end_head_id</span> <span class="o">=</span> <span class="n">num_heads</span>
|
||
|
||
<span class="k">if</span> <span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">rank_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="o">//</span> <span class="n">tp_size</span>
|
||
<span class="n">start_head_id</span> <span class="o">=</span> <span class="n">rank_heads</span> <span class="o">*</span> <span class="n">tp_rank</span>
|
||
<span class="n">end_head_id</span> <span class="o">=</span> <span class="n">start_head_id</span> <span class="o">+</span> <span class="n">rank_heads</span>
|
||
|
||
<span class="n">closest_power_of_2</span> <span class="o">=</span> <span class="mi">2</span><span class="o">**</span><span class="n">np</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">num_heads</span><span class="p">))</span>
|
||
<span class="c1"># FT's implementation</span>
|
||
<span class="c1"># https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/gen_relative_pos_bias.cu#L248</span>
|
||
<span class="n">slopes_ft</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">h_id</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">start_head_id</span><span class="p">,</span> <span class="n">end_head_id</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">h_id</span> <span class="o"><</span> <span class="n">closest_power_of_2</span><span class="p">:</span>
|
||
<span class="n">slopes_ft</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">power</span><span class="p">(</span>
|
||
<span class="mi">2</span><span class="o">**</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="mi">2</span><span class="o">**-</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">closest_power_of_2</span><span class="p">)</span> <span class="o">-</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">alibi_bias_max</span><span class="p">)))),</span> <span class="n">h_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">slopes_ft</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">power</span><span class="p">(</span>
|
||
<span class="mi">2</span><span class="o">**</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="mi">2</span><span class="o">**-</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">closest_power_of_2</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span> <span class="o">-</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">alibi_bias_max</span><span class="p">)))),</span>
|
||
<span class="p">(</span><span class="n">h_id</span> <span class="o">-</span> <span class="n">closest_power_of_2</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
|
||
<span class="n">slopes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">slopes_ft</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||
|
||
<span class="n">slopes</span> <span class="o">=</span> <span class="n">alibi_scale</span> <span class="o">*</span> <span class="n">slopes</span>
|
||
<span class="n">slopes</span> <span class="o">=</span> <span class="n">slopes</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="n">end_head_id</span> <span class="o">-</span> <span class="n">start_head_id</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">slopes</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="generate_alibi_biases">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.generate_alibi_biases">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">generate_alibi_biases</span><span class="p">(</span><span class="n">slopes</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">key_length</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Compute the ALiBi biases as described in https://arxiv.org/abs/2211.05100.</span>
|
||
|
||
<span class="sd"> The ALiBi biases are added to the result of the Q*K^T product in the</span>
|
||
<span class="sd"> multi-head attention block.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> slopes : Tensor</span>
|
||
<span class="sd"> The slopes.</span>
|
||
|
||
<span class="sd"> key_length : Tensor</span>
|
||
<span class="sd"> The size of the K vector per head.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A constant tensor that contains the ALiBi biases.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="c1"># We don't need to care about the batch size or query length since we can just broadcast</span>
|
||
<span class="c1"># across the batch and query dimensions</span>
|
||
|
||
<span class="n">trt_0</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
|
||
<span class="n">arange_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">key_length</span><span class="p">])</span>
|
||
|
||
<span class="n">arange_tensor</span> <span class="o">=</span> <span class="n">arange</span><span class="p">(</span><span class="n">trt_0</span><span class="p">,</span> <span class="n">key_length</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">arange_shape</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">slopes</span> <span class="o">*</span> <span class="n">arange_tensor</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="expand_mask">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_mask">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">expand_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_len</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Expand an attention mask.</span>
|
||
|
||
<span class="sd"> That function adds the sequence of operations to expand from a tensor of</span>
|
||
<span class="sd"> shape '[batch_size, src_seq_len]' to a tensor of shape</span>
|
||
<span class="sd"> '[batch_size, 1, tgt_seq_len, src_seq_len]'. It can be used to create the</span>
|
||
<span class="sd"> mask applied to the Q*K^T product before the softmax operation in the</span>
|
||
<span class="sd"> multi-head attention block.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> mask : Tensor</span>
|
||
<span class="sd"> The input mask</span>
|
||
|
||
<span class="sd"> tgt_len : Optional[Tensor]</span>
|
||
<span class="sd"> The dimension of the 3rd dimension in the output tensor. If None,</span>
|
||
<span class="sd"> the 2nd dimension of the input is used.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor created by that sequence of operations.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">bsz</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">src_len</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">tgt_len</span> <span class="o">=</span> <span class="n">tgt_len</span> <span class="k">if</span> <span class="n">tgt_len</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">src_len</span>
|
||
|
||
<span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">([</span><span class="n">bsz</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">src_len</span><span class="p">]))</span>
|
||
|
||
<span class="n">mask</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">concat</span><span class="p">([</span><span class="n">bsz</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">tgt_len</span><span class="p">,</span> <span class="n">src_len</span><span class="p">]))</span>
|
||
<span class="n">mask</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">'-inf'</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">mask</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="gather_last_token_logits">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gather_last_token_logits">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">gather_last_token_logits</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">last_token_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">remove_input_padding</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Extract the logits that correspond to the last token from the hidden states.</span>
|
||
|
||
<span class="sd"> That function adds the operations to extract the logits of the last tokens</span>
|
||
<span class="sd"> in a batch of sequences.</span>
|
||
|
||
<span class="sd"> Depending on whether 'remove_input_padding' is 'True' or 'False', that</span>
|
||
<span class="sd"> function assumes inputs of different shapes.</span>
|
||
|
||
<span class="sd"> When 'remove_input_padding' is 'True', the 'hidden_states' tensor is</span>
|
||
<span class="sd"> assumed to be packed. It has a shape '[num_tokens, hidden_dim]' where</span>
|
||
<span class="sd"> 'num_tokens' is the sum of the lengths of the sequences in the batch and</span>
|
||
<span class="sd"> 'hidden_dim' is the hidden dimension. The 'last_tokens_ids' is a 1D tensor</span>
|
||
<span class="sd"> that encodes the inclusive prefix-sums of the lengths of the sequences in</span>
|
||
<span class="sd"> the batch.</span>
|
||
|
||
<span class="sd"> When 'remove_input_padding' is 'False', the 'hidden_states' tensor is</span>
|
||
<span class="sd"> assumed to be padded. It has a shape '[batch_size, max_seqlen, hidden_dim]'</span>
|
||
<span class="sd"> where 'max_seqlen' is the length of the longest sequence in the batch and</span>
|
||
<span class="sd"> 'hidden_dim' is the hidden dimension. The 'last_token_ids' is a 1D tensor</span>
|
||
<span class="sd"> that encodes the length of each sequence in the batch.</span>
|
||
|
||
<span class="sd"> In both cases, that function produces a tensor of shape '[batch_size,</span>
|
||
<span class="sd"> hidden_size]' where the row at index 'i' corresponds to the logits of the</span>
|
||
<span class="sd"> last token from the 'i'-th sequence.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> hidden_states : Tensor</span>
|
||
<span class="sd"> The hidden states</span>
|
||
|
||
<span class="sd"> last_token_ids : Tensor</span>
|
||
<span class="sd"> The inclusive prefix-sum of the lengths or the lengths of the</span>
|
||
<span class="sd"> sequences in the batch.</span>
|
||
|
||
<span class="sd"> remove_input_padding : bool</span>
|
||
<span class="sd"> Indicate if the hidden_states are packed ('True') or padded</span>
|
||
<span class="sd"> ('False').</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor created by that sequence of operations.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">last_token_ids</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">hidden_states</span>
|
||
|
||
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">index_select</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># [seq_len, hidden]</span>
|
||
|
||
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">1</span><span class="p">)]))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="c1"># only calculate logits for the last token</span>
|
||
<span class="c1"># [batch_size, seqlen, hidden_size] -> [batch_size, hidden_size]</span>
|
||
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">]))</span>
|
||
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span>
|
||
<span class="n">last_token_ids</span><span class="p">,</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)]))</span>
|
||
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span>
|
||
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span>
|
||
<span class="n">hidden_states</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">indices</span><span class="o">=</span><span class="n">last_token_ids</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)]))</span>
|
||
<span class="k">elif</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span> <span class="c1"># speculative decoding needs last few token's logits</span>
|
||
<span class="c1"># last_token_ids is of shape [batch_size, num_last_tokens]</span>
|
||
<span class="c1"># So [batch_size, seqlen, hidden_size] -> [batch_size, num_last_tokens, hidden_size]</span>
|
||
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">]))</span>
|
||
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span>
|
||
<span class="n">last_token_ids</span><span class="p">,</span>
|
||
<span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||
<span class="p">]))</span>
|
||
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">indices</span><span class="o">=</span><span class="n">last_token_ids</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">hidden_states</span></div>
|
||
|
||
|
||
|
||
<span class="n">ACT2FN</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="s1">'relu'</span><span class="p">:</span> <span class="n">relu</span><span class="p">,</span>
|
||
<span class="s1">'tanh'</span><span class="p">:</span> <span class="n">tanh</span><span class="p">,</span>
|
||
<span class="s1">'gelu'</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
|
||
<span class="s1">'gelu_new'</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
|
||
<span class="s1">'gelu_fast'</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
|
||
<span class="s1">'gelu_pytorch_tanh'</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
|
||
<span class="s1">'openai-gelu'</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
|
||
<span class="s1">'geglu'</span><span class="p">:</span> <span class="n">geglu</span><span class="p">,</span>
|
||
<span class="s1">'gegelu'</span><span class="p">:</span> <span class="n">gegelu</span><span class="p">,</span>
|
||
<span class="s1">'identity'</span><span class="p">:</span> <span class="n">identity</span><span class="p">,</span>
|
||
<span class="s1">'silu'</span><span class="p">:</span> <span class="n">silu</span><span class="p">,</span>
|
||
<span class="s1">'softplus'</span><span class="p">:</span> <span class="n">softplus</span><span class="p">,</span>
|
||
<span class="s1">'relu2'</span><span class="p">:</span> <span class="n">squared_relu</span><span class="p">,</span>
|
||
<span class="s1">'squared-relu'</span><span class="p">:</span> <span class="n">squared_relu</span><span class="p">,</span>
|
||
<span class="s1">'swiglu'</span><span class="p">:</span> <span class="n">swiglu</span><span class="p">,</span>
|
||
<span class="s1">'fast-swiglu'</span><span class="p">:</span> <span class="n">swiglu</span><span class="p">,</span>
|
||
<span class="p">}</span>
|
||
|
||
<span class="n">GATED_ACT_2_ACT</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="s1">'swiglu'</span><span class="p">:</span> <span class="s1">'silu'</span><span class="p">,</span>
|
||
<span class="s1">'fast-swiglu'</span><span class="p">:</span> <span class="s1">'silu'</span><span class="p">,</span>
|
||
<span class="s1">'geglu'</span><span class="p">:</span> <span class="s1">'gelu'</span><span class="p">,</span>
|
||
<span class="p">}</span>
|
||
|
||
|
||
<div class="viewcode-block" id="is_gated_activation">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.is_gated_activation">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">is_gated_activation</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Is a given activation function gated?</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> activation : str</span>
|
||
<span class="sd"> The name of the activation function.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> True if the function is gated, False otherwise.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">activation</span> <span class="ow">in</span> <span class="n">ACT2FN</span>
|
||
<span class="k">return</span> <span class="n">activation</span> <span class="ow">in</span> <span class="n">GATED_ACT_2_ACT</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="non_gated_version"><!-- Rendered source of non_gated_version(activation): returns GATED_ACT_2_ACT[activation] when is_gated_activation(activation) is True, otherwise returns activation unchanged. Because it calls is_gated_activation, a name that is not an ACT2FN key trips the same assertion. NOTE(review): the docstring's "that function returns" reads better as "this function returns"; fix the wording in the Python source of tensorrt_llm.functional, since this page is generated from it. -->
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.non_gated_version">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">non_gated_version</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Given an activation function, get the non-gated version.</span>
|
||
|
||
<span class="sd"> If the activation function is non-gated, it returns the same activation</span>
|
||
<span class="sd"> function name.</span>
|
||
|
||
<span class="sd"> For example, that function returns 'silu' for 'swiglu' and 'relu' for</span>
|
||
<span class="sd"> 'relu'.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> activation : str</span>
|
||
<span class="sd"> The name of the activation function.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The name of the non-gated activation function.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">is_gated_activation</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">GATED_ACT_2_ACT</span><span class="p">[</span><span class="n">activation</span><span class="p">]</span>
|
||
<span class="k">return</span> <span class="n">activation</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="lora_plugin"><!-- Rendered source of lora_plugin(). From the code below, it:
  requires host_context_lengths whenever plugin_config.remove_input_padding is set (assert);
  builds a TensorRT plugin from creator 'Lora' version '1' in TRT_LLM_PLUGIN_NAMESPACE;
  casts input to the plugin_config.lora_plugin dtype;
  feeds the plugin input, host_request_types, lora_ranks and lora_weights_pointers, plus host_context_lengths when padding is removed.
It returns a single tensor when len(out_hidden_sizes) == 1, and otherwise a list with one tensor per entry of out_hidden_sizes. Every output is cast back to input.dtype. This page is generated, so source-level fixes belong in the Python source of tensorrt_llm.functional. -->
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.lora_plugin">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">lora_plugin</span><span class="p">(</span>
|
||
<span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">in_hidden_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">out_hidden_sizes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">],</span><!-- NOTE(review): mutable list default [0]. The visible body only reads it (enumerate, len), so it is not mutated here, but a tuple default would be safer. The docstring documents this as "in_hidden_size/out_hidden_size : int", which does not match this List[int] parameter. -->
|
||
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">transa</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">transb</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for pad-free input mode</span>
|
||
<span class="n">max_low_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">lora_ranks</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span><!-- NOTE(review): lora_ranks and lora_weights_pointers default to None, yet the body concatenates them onto a list when building plug_inputs, so leaving either as None raises TypeError; callers must pass lists of Tensors. The docstring describes each as a single cpu Tensor; presumably it is one Tensor per LoRA module, confirm against callers. -->
|
||
<span class="n">lora_weights_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">weight_index</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor (On GPU)</span>
|
||
<span class="sd"> The input tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding</span>
|
||
|
||
<span class="sd"> in_hidden_size/out_hidden_size : int</span>
|
||
<span class="sd"> the lora computation workflow is</span>
|
||
<span class="sd"> [M, in_hidden_size] -> [M, low_rank] -> [M, out_hidden_size]</span>
|
||
|
||
<span class="sd"> host_request_types : Tensor = None</span>
|
||
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
|
||
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
|
||
<span class="sd"> in docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> transa : bool</span>
|
||
<span class="sd"> Is the first input transposed? Set to 'True' if you want the first</span>
|
||
<span class="sd"> input to be transposed, 'False' otherwise.</span>
|
||
|
||
<span class="sd"> transb : bool</span>
|
||
<span class="sd"> Is the second input transposed? Set to 'True' if you want the</span>
|
||
<span class="sd"> second input to be transposed, 'False' otherwise.</span>
|
||
|
||
<span class="sd"> host_context_lengths: cpu Tensor = None</span>
|
||
<span class="sd"> A host tensor that contains the lengths of the different inputs,</span>
|
||
|
||
<span class="sd"> max_low_rank : int</span>
|
||
<span class="sd"> Maximum low_rank, used to determine the workspace size.</span>
|
||
|
||
<span class="sd"> lora_ranks : cpu Tensor with shape [batch_size]</span>
|
||
<span class="sd"> The low_rank of each request</span>
|
||
|
||
<span class="sd"> lora_weights_pointers : cpu int64 Tensor with shape [batch_size, 2]</span>
|
||
<span class="sd"> The weights pointers of each request. Consist of in_pointer and out_pointer.</span>
|
||
|
||
<span class="sd"> weight_index : int</span>
|
||
<span class="sd"> The index of weight if the weight pointer pointing to multiple weights.</span>
|
||
|
||
<span class="sd"> Return:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">host_context_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">(</span>
|
||
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
|
||
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_creator_list</span><!-- NOTE(review): the value of plugin_creator_list is discarded. This looks like a no-op unless the property access has a side effect; confirm before removing it from the Python source. -->
|
||
<span class="n">in_hidden_size_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"in_hidden_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">in_hidden_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">out_hidden_size_field_list</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="sa">f</span><span class="s2">"out_hidden_size_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">o</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">o</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">out_hidden_sizes</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
<span class="n">transa</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">transa</span> <span class="k">else</span> <span class="mi">0</span>
|
||
<span class="n">transa</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"transa"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transa</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span><!-- transa (and transb below) is rebound from bool to int and then to a PluginField; these PluginField objects are what go into pfc. -->
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">transb</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">transb</span> <span class="k">else</span> <span class="mi">0</span>
|
||
<span class="n">transb</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"transb"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transb</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Lora'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lora_plugin</span><!-- The plugin data type comes from plugin_config.lora_plugin: input is cast to it for the plugin, and the outputs are cast back to input.dtype on return. -->
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"remove_input_padding"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">max_low_rank_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"max_low_rank"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_low_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">weight_index_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"weight_index"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">weight_index</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">num_lora_modules</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">out_hidden_sizes</span><span class="p">)</span>
|
||
<span class="n">num_lora_modules_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"num_lora_modules"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_lora_modules</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
|
||
<span class="n">in_hidden_size_field</span><span class="p">,</span> <span class="n">transa</span><span class="p">,</span> <span class="n">transb</span><span class="p">,</span> <span class="n">num_lora_modules_field</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span>
|
||
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">max_low_rank_field</span><span class="p">,</span> <span class="n">weight_index_field</span>
|
||
<span class="p">]</span> <span class="o">+</span> <span class="n">out_hidden_size_field_list</span><span class="p">)</span>
|
||
<span class="n">lora_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"lora"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">),</span> <span class="n">host_request_types</span>
|
||
<span class="p">]</span> <span class="o">+</span> <span class="n">lora_ranks</span> <span class="o">+</span> <span class="n">lora_weights_pointers</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span><!-- NOTE(review): every plug_inputs entry must expose trt_tensor, so leaving host_request_types at its None default raises AttributeError here. Leaving input as None already fails earlier, at input.cast. -->
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">lora_plug</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">num_lora_modules</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span><!-- The return type depends on len(out_hidden_sizes): one tensor for a single module, otherwise a list. An empty out_hidden_sizes gives num_lora_modules == 0, so the else branch returns an empty list. -->
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="p">[</span>
|
||
<span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="n">i</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_lora_modules</span><span class="p">)</span>
|
||
<span class="p">]</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="mamba_conv1d">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.mamba_conv1d">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">mamba_conv1d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">conv_state_or_ptr</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">conv_weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">conv_bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dconv</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
|
||
<span class="n">pre_stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">post_stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">slot_mapping</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">apply_silu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor (On GPU)</span>
|
||
<span class="sd"> The input tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding</span>
|
||
|
||
<span class="sd"> conv_state_or_ptr : Tensor (On GPU or CPU)</span>
|
||
<span class="sd"> The conv state tensor. Its shape is [batch_size, dconv - 1, dim]</span>
|
||
<span class="sd"> Or the CPU tensor of shape [1] for the pointer of paged states.</span>
|
||
|
||
<span class="sd"> conv_weight : Tensor (On GPU)</span>
|
||
<span class="sd"> The weight tensor. Its shape is [1, dconv, dim]</span>
|
||
|
||
<span class="sd"> conv_bias : Tensor (On GPU)</span>
|
||
<span class="sd"> The bias tensor. Its shape is [dim]</span>
|
||
|
||
<span class="sd"> host_request_types : Tensor (On CPU)</span>
|
||
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
|
||
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
|
||
<span class="sd"> in docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> last_token_ids : Tensor (On GPU)</span>
|
||
<span class="sd"> The inclusive prefix-sum of the lengths or the lengths of the</span>
|
||
<span class="sd"> sequences in the batch.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The hidden dimension of conv1d</span>
|
||
|
||
<span class="sd"> dconv : int</span>
|
||
<span class="sd"> The window size of conv1d</span>
|
||
|
||
<span class="sd"> dtype: str</span>
|
||
<span class="sd"> data type</span>
|
||
|
||
<span class="sd"> pre_stride : int = 0</span>
|
||
<span class="sd"> The (pre) stride size of the input tensor.</span>
|
||
<span class="sd"> The valid values of the input tensor are input[..., pre_stride: dim-post_stride]</span>
|
||
|
||
<span class="sd"> post_stride : int = 0</span>
|
||
<span class="sd"> The (post) stride size of the input tensor.</span>
|
||
<span class="sd"> The valid values of the input tensor are input[..., pre_stride: dim-post_stride]</span>
|
||
|
||
<span class="sd"> host_context_lengths: Tensor (On CPU) (Optional)</span>
|
||
<span class="sd"> A host tensor that contains the lengths of the different inputs,</span>
|
||
|
||
<span class="sd"> slot_mapping: Tensor (On GPU) (Optional)</span>
|
||
<span class="sd"> Real page index in state. Its shape is [dim], used for paged state, each page shape is [dconv, dim]</span>
|
||
|
||
<span class="sd"> apply_silu: bool</span>
|
||
<span class="sd"> Is there a SiLU operation after the conv1d? When True apply</span>
|
||
<span class="sd"> SiLU activation function after the conv1d.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">host_request_types</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">mamba_conv1d_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'MambaConv1d'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">mamba_conv1d_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"dim"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">dconv</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"dconv"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dconv</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pre_stride</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"pre_stride"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">pre_stride</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">post_stride</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"post_stride"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">post_stride</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"remove_input_padding"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">paged_state</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"paged_state"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">apply_silu</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"apply_silu"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">apply_silu</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
|
||
<span class="n">dim</span><span class="p">,</span> <span class="n">dconv</span><span class="p">,</span> <span class="n">pre_stride</span><span class="p">,</span> <span class="n">post_stride</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span>
|
||
<span class="n">paged_state</span><span class="p">,</span> <span class="n">apply_silu</span>
|
||
<span class="p">])</span>
|
||
<span class="n">mamba_conv1d_plug</span> <span class="o">=</span> <span class="n">mamba_conv1d_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span>
|
||
<span class="s2">"mamba_conv1d"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="nb">input</span><span class="p">,</span> <span class="n">conv_state_or_ptr</span><span class="p">,</span> <span class="n">conv_weight</span><span class="p">,</span> <span class="n">conv_bias</span><span class="p">,</span> <span class="n">host_request_types</span><span class="p">,</span>
|
||
<span class="n">last_token_ids</span>
|
||
<span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">slot_mapping</span><span class="p">]</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">mamba_conv1d_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">mamba_conv1d_plg_creator</span><span class="p">,</span> <span class="s2">"mamba_conv1d"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="kc">None</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">present_state</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">present_state</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="selective_scan">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.selective_scan">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">selective_scan</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">state_or_ptr</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">delta</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">delta_bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">A</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">BC</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">D</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dstate</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dt_rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">delta_softplus</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
|
||
<span class="n">z</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">slot_mapping</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">nheads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">ngroups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">chunk_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span>
|
||
<span class="n">mamba_version</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'Mamba1'</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor (On GPU)</span>
|
||
<span class="sd"> The input tensor. Its shape is [batch_size, seq_len, dim]</span>
|
||
|
||
<span class="sd"> state_or_ptr : Tensor (On GPU or CPU)</span>
|
||
<span class="sd"> The ssm state tensor. Its shape is [batch_size, dstate, dim]</span>
|
||
<span class="sd"> Or the CPU tensor of shape [1] for the pointer of paged states.</span>
|
||
|
||
<span class="sd"> delta : Tensor (On GPU)</span>
|
||
<span class="sd"> The delta tensor.</span>
|
||
<span class="sd"> mamba: Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding</span>
|
||
<span class="sd"> mamba2: Its shape is [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding</span>
|
||
|
||
<span class="sd"> delta_bias : Tensor (On GPU)</span>
|
||
<span class="sd"> The delta bias tensor.</span>
|
||
<span class="sd"> mamba: Its shape is [dim]</span>
|
||
<span class="sd"> mamba2: Its shape is [nheads]</span>
|
||
|
||
<span class="sd"> A : Tensor (On GPU)</span>
|
||
<span class="sd"> A matrix.</span>
|
||
<span class="sd"> mamba: Its shape is [dstate, dim]</span>
|
||
<span class="sd"> mamba2: Its shape is [nheads]</span>
|
||
|
||
<span class="sd"> BC : Tensor (On GPU)</span>
|
||
<span class="sd"> B and C matrix.</span>
|
||
<span class="sd"> mamba: Its shape is [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding</span>
|
||
<span class="sd"> mamba2: Its shape is [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for remove_input_padding</span>
|
||
|
||
<span class="sd"> D : Tensor (On GPU)</span>
|
||
<span class="sd"> D matrix.</span>
|
||
<span class="sd"> mamba: Its shape is [dim]</span>
|
||
<span class="sd"> mamba2: Its shape is [nheads]</span>
|
||
|
||
<span class="sd"> host_request_types : Tensor (On CPU)</span>
|
||
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
|
||
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
|
||
<span class="sd"> in docs/source/advanced/gpt-attention.md</span>
|
||
|
||
<span class="sd"> last_token_ids : Tensor (On GPU)</span>
|
||
<span class="sd"> The inclusive prefix-sum of the lengths or the lengths of the</span>
|
||
<span class="sd"> sequences in the batch.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The inner dimension of SSM block</span>
|
||
|
||
<span class="sd"> dstate : int</span>
|
||
<span class="sd"> The state dimension of SSM block</span>
|
||
|
||
<span class="sd"> dt_rank: int</span>
|
||
<span class="sd"> The rank dimension of dt_proj</span>
|
||
|
||
<span class="sd"> delta_softplus : bool</span>
|
||
<span class="sd"> Whether to apply softplus to the delta.</span>
|
||
|
||
<span class="sd"> dtype: str</span>
|
||
<span class="sd"> Data type name; it is converted with str_dtype_to_trt.</span>
|
||
|
||
<span class="sd"> z : Tensor (On GPU) (Optional)</span>
|
||
<span class="sd"> The z tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding</span>
|
||
|
||
<span class="sd"> host_context_lengths: Tensor (On CPU) (Optional)</span>
|
||
<span class="sd"> A host tensor that contains the lengths of the different inputs.</span>
|
||
<span class="sd"> Required when remove_input_padding is enabled.</span>
|
||
|
||
<span class="sd"> slot_mapping: Tensor (On GPU) (Optional)</span>
|
||
<span class="sd"> Real page index in state. Its shape is [dim], used for paged state, each page shape is [dstate, dim]</span>
|
||
<span class="sd"> Required when paged_state is enabled.</span>
|
||
|
||
<span class="sd"> nheads: int (Optional)</span>
|
||
<span class="sd"> The number of heads.</span>
|
||
|
||
<span class="sd"> ngroups: int (Optional)</span>
|
||
<span class="sd"> The number of groups.</span>
|
||
|
||
<span class="sd"> chunk_size: int (Optional)</span>
|
||
<span class="sd"> The chunk_size is used for the chunk_scan kernel.</span>
|
||
|
||
<span class="sd"> mamba_version: str (Optional)</span>
|
||
<span class="sd"> Mamba version: 'Mamba1' (default) or 'Mamba2'.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tuple (output, present_state). present_state is None when</span>
|
||
<span class="sd"> paged_state is enabled; otherwise it is the second plugin output.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">host_request_types</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">selective_scan_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'SelectiveScan'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">selective_scan_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"dim"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">dstate</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"dstate"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dstate</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">dt_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"dt_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dt_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">nheads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"nheads"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">nheads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">ngroups</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"ngroups"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">ngroups</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">chunk_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"chunk_size"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">chunk_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">delta_softplus</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"delta_softplus"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">delta_softplus</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"remove_input_padding"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">paged_state</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"paged_state"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">z</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">z_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"z_enabled"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">z_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"z_enabled"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">is_mamba2</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"is_mamba2"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span> <span class="k">if</span> <span class="n">mamba_version</span> <span class="o">==</span> <span class="s1">'Mamba2'</span> <span class="k">else</span> <span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
|
||
<span class="n">dim</span><span class="p">,</span> <span class="n">dstate</span><span class="p">,</span> <span class="n">dt_rank</span><span class="p">,</span> <span class="n">nheads</span><span class="p">,</span> <span class="n">ngroups</span><span class="p">,</span> <span class="n">chunk_size</span><span class="p">,</span> <span class="n">delta_softplus</span><span class="p">,</span>
|
||
<span class="n">pf_type</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">paged_state</span><span class="p">,</span> <span class="n">z_enabled</span><span class="p">,</span> <span class="n">is_mamba2</span>
|
||
<span class="p">])</span>
|
||
<span class="n">selective_scan_plug</span> <span class="o">=</span> <span class="n">selective_scan_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span>
|
||
<span class="s2">"selective_scan"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="nb">input</span><span class="p">,</span> <span class="n">state_or_ptr</span><span class="p">,</span> <span class="n">delta</span><span class="p">,</span> <span class="n">delta_bias</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">BC</span><span class="p">,</span> <span class="n">D</span><span class="p">,</span> <span class="n">host_request_types</span><span class="p">,</span>
|
||
<span class="n">last_token_ids</span>
|
||
<span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="c1"># Fail early rather than with an AttributeError on None.trt_tensor below.</span>
|
||
<span class="k">assert</span> <span class="n">host_context_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">slot_mapping</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">slot_mapping</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">z</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">z</span><span class="p">]</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">selective_scan_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">selective_scan_plg_creator</span><span class="p">,</span> <span class="s2">"selective_scan"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="kc">None</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">present_state</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">present_state</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="rg_lru">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.rg_lru">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">rg_lru</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">A</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">state_or_ptr</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
|
||
<span class="n">block_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">y</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">y_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">gate</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">gate_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">gate_x</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">gate_x_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">gate_a</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">gate_a_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">slot_mapping</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor (On GPU)</span>
|
||
<span class="sd"> The input tensor. Its shape is [batch_size, seq_len, dim]</span>
|
||
|
||
<span class="sd"> A : Tensor (On GPU)</span>
|
||
<span class="sd"> A matrix. Its shape is [dim]</span>
|
||
|
||
<span class="sd"> state_or_ptr : Tensor (On GPU or CPU)</span>
|
||
<span class="sd"> The lru state tensor. Its shape is [batch_size, dstate, dim]</span>
|
||
<span class="sd"> Or the CPU tensor of shape [1] for the pointer of paged states.</span>
|
||
|
||
<span class="sd"> host_request_types : Tensor (On CPU)</span>
|
||
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
|
||
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
|
||
<span class="sd"> in docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> last_token_ids : Tensor (On GPU)</span>
|
||
<span class="sd"> The inclusive prefix-sum of the lengths or the lengths of the</span>
|
||
<span class="sd"> sequences in the batch.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The inner dimension of RG_LRU block</span>
|
||
|
||
<span class="sd"> block_size : int</span>
|
||
<span class="sd"> The block size of the block diagonal linear layer. It is used to</span>
|
||
<span class="sd"> support the cases that enable fused gate.</span>
|
||
|
||
<span class="sd"> dtype: str</span>
|
||
<span class="sd"> data type</span>
|
||
|
||
<span class="sd"> y : Tensor (On GPU) (Optional)</span>
|
||
<span class="sd"> The y tensor. Its shape is [batch_size, seq_len, dim]</span>
|
||
|
||
<span class="sd"> y_bias : Tensor (On GPU) (Optional)</span>
|
||
<span class="sd"> The y_bias tensor. Its shape is [dim]. If y_bias is not None, we</span>
|
||
<span class="sd"> will fuse GELU(y + y_bias) in this function.</span>
|
||
|
||
<span class="sd"> gate : Tensor (On GPU) (Optional)</span>
|
||
<span class="sd"> The gate tensor. Its shape is [batch_size, seq_len, 2 * dim].</span>
|
||
<span class="sd"> If gate is not None, we will fuse the gate_x and gate_a, otherwise</span>
|
||
<span class="sd"> use those two tensors.</span>
|
||
|
||
<span class="sd"> gate_bias : Tensor (On GPU) (Optional)</span>
|
||
<span class="sd"> The gate_bias tensor. Its shape is [2 * block_num, dim // block_num].</span>
|
||
<span class="sd"> If gate_bias is not None, we will fuse the bias add in this function.</span>
|
||
|
||
<span class="sd"> gate_x : Tensor (On GPU) (Optional)</span>
|
||
<span class="sd"> The gate_x tensor. Its shape is [batch_size, seq_len, dim]</span>
|
||
|
||
<span class="sd"> gate_x_bias : Tensor (On GPU) (Optional)</span>
|
||
<span class="sd"> The gate_x_bias tensor. Its shape is [block_num, dim // block_num].</span>
|
||
<span class="sd"> If gate_x_bias is not None, we will fuse the bias add in this function.</span>
|
||
|
||
<span class="sd"> gate_a : Tensor (On GPU) (Optional)</span>
|
||
<span class="sd"> The gate_a tensor. Its shape is [batch_size, seq_len, dim]</span>
|
||
|
||
<span class="sd"> gate_a_bias : Tensor (On GPU) (Optional)</span>
|
||
<span class="sd"> The gate_a_bias tensor. Its shape is [block_num, dim // block_num].</span>
|
||
<span class="sd"> If gate_a_bias is not None, we will fuse the bias add in this function.</span>
|
||
|
||
<span class="sd"> slot_mapping: Tensor (On GPU) (Optional)</span>
|
||
<span class="sd"> Real page index in state. Its shape is [dim], used for paged state, each page shape is [dstate, dim]</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">host_request_types</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">lru_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'LRU'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">lru_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="p">(</span><span class="n">gate_x_bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">)</span> <span class="o">==</span> <span class="p">(</span><span class="n">gate_a_bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">)</span>
|
||
<span class="n">enable_fuse_gate</span> <span class="o">=</span> <span class="n">gate</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">has_gate_bias</span> <span class="o">=</span> <span class="p">(</span><span class="n">gate_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">gate_x_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">enable_fuse_gate</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">gate</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">block_size</span> <span class="o">></span> <span class="mi">0</span>
|
||
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">gate_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">gate_x</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">gate_a</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">gate_x_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">gate_a_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"dim"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">block_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"block_size"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">block_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"remove_input_padding"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">paged_state</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"paged_state"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">y</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">y_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"y_enabled"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">y_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"y_enabled"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">y_bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">y_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"y_bias_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">y_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"y_bias_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">enable_fuse_gate</span><span class="p">:</span>
|
||
<span class="n">fuse_gate_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"fuse_gate_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">fuse_gate_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"fuse_gate_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
|
||
<span class="n">gate_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"gate_bias_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">gate_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"gate_bias_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
|
||
<span class="n">dim</span><span class="p">,</span> <span class="n">block_size</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">paged_state</span><span class="p">,</span> <span class="n">y_enabled</span><span class="p">,</span>
|
||
<span class="n">y_bias_enabled</span><span class="p">,</span> <span class="n">fuse_gate_enabled</span><span class="p">,</span> <span class="n">gate_bias_enabled</span>
|
||
<span class="p">])</span>
|
||
<span class="n">lru_plug</span> <span class="o">=</span> <span class="n">lru_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"rg_lru"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="nb">input</span><span class="p">,</span>
|
||
<span class="n">A</span><span class="p">,</span>
|
||
<span class="n">state_or_ptr</span><span class="p">,</span>
|
||
<span class="n">host_request_types</span><span class="p">,</span>
|
||
<span class="n">last_token_ids</span><span class="p">,</span>
|
||
<span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">slot_mapping</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">y</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">y</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">y_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">y_bias</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">enable_fuse_gate</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gate</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gate_bias</span><span class="p">]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gate_x</span><span class="p">,</span> <span class="n">gate_a</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gate_x_bias</span><span class="p">,</span> <span class="n">gate_a_bias</span><span class="p">]</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">lru_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">lru_plg_creator</span><span class="p">,</span> <span class="s2">"rg_lru"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="kc">None</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">present_state</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">present_state</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="topk">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.topk">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">topk</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">k</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">largest</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">prefer_plugin</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an topk operation.</span>
|
||
|
||
<span class="sd"> As explained in the ONNX documentation,</span>
|
||
|
||
<span class="sd"> https://github.com/onnx/onnx/blob/main/docs/Operators.md#topk</span>
|
||
|
||
<span class="sd"> NOTE: One distinction from the ONNX topk op, the output is always sorted</span>
|
||
<span class="sd"> with TensorRT layer.</span>
|
||
|
||
<span class="sd"> Retrieve the top-K largest elements along a specified axis.</span>
|
||
<span class="sd"> Given an input tensor of shape [a_1, a_2, ..., a_n, r]</span>
|
||
<span class="sd"> and integer argument k, return two outputs:</span>
|
||
<span class="sd"> Value tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which contains the values of the top k elements along the specified axis</span>
|
||
<span class="sd"> Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which contains the indices of the top k elements (original indices from the input tensor).</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> k : int</span>
|
||
<span class="sd"> A single positive value corresponding to the number of top elements to retrieve</span>
|
||
|
||
<span class="sd"> dim: int</span>
|
||
<span class="sd"> The dimension in which to compute the topk indices.</span>
|
||
|
||
<span class="sd"> largest: bool</span>
|
||
<span class="sd"> Controls whether to return largest or smallest elements</span>
|
||
|
||
<span class="sd"> prefer_plugin : bool</span>
|
||
<span class="sd"> Whether to use the topkLastDim plugin if dim is last dim and k is static.</span>
|
||
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensors (values, indices) produced by this topk operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">prefer_plugin</span> <span class="ow">and</span> <span class="n">dim</span> <span class="o">==</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="n">last_dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">last_dim</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span> <span class="c1"># dynamic?</span>
|
||
<span class="n">last_dim</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="c1"># since we might need to flatten the input to 2d tensor,</span>
|
||
<span class="c1"># we need to prepare the output shape</span>
|
||
<span class="n">out_shape</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
|
||
<span class="n">out_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
|
||
<span class="n">out_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">out_shape</span> <span class="o">+</span> <span class="p">[</span><span class="n">k</span><span class="p">])</span>
|
||
<span class="k">if</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">input_2d</span> <span class="o">=</span> <span class="n">unsqueeze</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span>
|
||
<span class="mi">0</span><span class="p">)</span> <span class="c1"># special handling of rank-1 dynamic tensor</span>
|
||
<span class="k">elif</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span>
|
||
<span class="n">input_2d</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">last_dim</span><span class="p">]),</span>
|
||
<span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">input_2d</span> <span class="o">=</span> <span class="nb">input</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s2">"TopkLastDim"</span><span class="p">,</span> <span class="s2">"1"</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">is_largest</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"is_largest"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span> <span class="k">if</span> <span class="n">largest</span> <span class="k">else</span> <span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">k</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"k"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"type_id"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">input_2d</span><span class="o">.</span><span class="n">dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">pf_type</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">is_largest</span><span class="p">])</span>
|
||
<span class="n">topk_last_dim_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"topk_last_dim"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">input_2d</span><span class="p">]</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">topk_last_dim_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">"topk_last_dim"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">values</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="n">indices</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="n">values</span> <span class="o">=</span> <span class="n">values</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">out_shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
<span class="n">indices</span> <span class="o">=</span> <span class="n">indices</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">out_shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="c1"># non-plugin path</span>
|
||
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_topk</span><span class="p">(</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">TopKOperation</span><span class="o">.</span><span class="n">MAX</span> <span class="k">if</span> <span class="n">largest</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">TopKOperation</span><span class="o">.</span><span class="n">MIN</span><span class="p">,</span>
|
||
<span class="n">k</span><span class="o">=</span><span class="n">k</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">axes</span><span class="o">=</span><span class="n">axes</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">k</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">k</span> <span class="o">=</span> <span class="n">squeeze</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">k</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">values</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="n">indices</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">values</span><span class="p">,</span> <span class="n">indices</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="scatter_nd">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.scatter_nd">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">scatter_nd</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">source</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Scatter_nd is a tensor operation that writes or updates values in a tensor based on indices.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input: Tensor</span>
|
||
<span class="sd"> The input tensor to be updated</span>
|
||
<span class="sd"> mask: Tensor</span>
|
||
<span class="sd"> A tensor of indices specifying the locations in data to be updated.</span>
|
||
<span class="sd"> source: Tensor</span>
|
||
<span class="sd"> A tensor of values to be written or scattered into data.</span>
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> New tensor with the same shape as the input tensor data,</span>
|
||
<span class="sd"> where the values from the source tensor are scattered or written into the output tensor</span>
|
||
<span class="sd"> at the locations specified by the mask tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">scatter_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_scatter</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">mask</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">source</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ScatterMode</span><span class="o">.</span><span class="n">ND</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">scatter_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">scatter_layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="low_latency_gemm">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.low_latency_gemm">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">low_latency_gemm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">mat2</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">alpha</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">strict_dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">low_latency_gemm_plugin</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">"Low Latency GEMM is only support with plugin"</span><span class="p">)</span>
|
||
<span class="k">elif</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">low_latency_gemm_plugin</span> <span class="o">!=</span> <span class="s2">"fp8"</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">"Low Latency GEMM plugin only support fp8"</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s2">"LowLatencyGemm"</span><span class="p">,</span> <span class="s2">"1"</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">if</span> <span class="p">((</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span> <span class="o">!=</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span><span class="p">)</span> <span class="ow">or</span> <span class="p">((</span><span class="n">mat2</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="o">!=</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span><span class="p">)):</span>
|
||
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"Low Latency GEMM only support fp8 input"</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="p">(</span><span class="n">alpha</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="ow">and</span> <span class="n">alpha</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span>
|
||
<span class="ow">and</span> <span class="n">alpha</span><span class="o">.</span><span class="n">size</span>
|
||
<span class="o">==</span> <span class="mi">1</span><span class="p">),</span> <span class="s2">"`alpha` must be passed as a float32 ndarray"</span>
|
||
<span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span> <span class="k">if</span> <span class="n">alpha</span> <span class="k">else</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||
<span class="n">alpha</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"alpha"</span><span class="p">,</span> <span class="n">alpha</span><span class="o">.</span><span class="n">flatten</span><span class="p">(),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">strict_dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">strict_dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span>
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">strict_dtype</span>
|
||
<span class="k">if</span> <span class="p">(</span><span class="n">p_dtype</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">]):</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"strict_dtype must be float32, float16 or bfloat16 in low latency gemm plugin"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
|
||
<span class="s2">"need to use strict dtype in low latency gemm plugin fp8"</span><span class="p">)</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">alpha</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">low_latency_gemm_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span>
|
||
<span class="s2">"low_latency_gemm"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">mat2</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span>
|
||
<span class="n">low_latency_gemm_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">"low_latency_gemm"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="SideStreamIDType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.SideStreamIDType">[docs]</a>
|
||
<span class="k">class</span> <span class="nc">SideStreamIDType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">disable</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">moe</span> <span class="o">=</span> <span class="mi">1</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="low_latency_gemm_swiglu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.low_latency_gemm_swiglu">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">low_latency_gemm_swiglu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">scale_d0</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">scale_d1</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">scale_output</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a matrix multiplication, followed by SwiGLU (`x * SiLU(gate)`) operation.</span>
|
||
|
||
<span class="sd"> The second SwiGLU operation takes the preceding tensor, splits it into two halves</span>
|
||
<span class="sd"> along the last dimension, applies SiLU to the second half and multiply the results. The</span>
|
||
<span class="sd"> behaviour is undefined if the last dimension is not even.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The first tensor (often called A).</span>
|
||
|
||
<span class="sd"> weight : Tensor</span>
|
||
<span class="sd"> The second tensor (often called B).</span>
|
||
|
||
<span class="sd"> scale_d0 : float</span>
|
||
<span class="sd"> The scale for dequantizing x, used for fp8</span>
|
||
|
||
<span class="sd"> scale_d1 : float</span>
|
||
<span class="sd"> The scale for dequantizing gate, used for fp8</span>
|
||
|
||
<span class="sd"> scale_output : float</span>
|
||
<span class="sd"> The scale for quantizing output, used for fp8</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the inserted layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'LowLatencyGemmSwiglu'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">low_latency_gemm_swiglu_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pf_scale_d0</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"scale_d0"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_d0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">pf_scale_d1</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"scale_d1"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_d1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">pf_scale_output</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"scale_output"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_output</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span>
|
||
<span class="p">[</span><span class="n">pf_type</span><span class="p">,</span> <span class="n">pf_scale_output</span><span class="p">,</span> <span class="n">pf_scale_d0</span><span class="p">,</span> <span class="n">pf_scale_d1</span><span class="p">])</span>
|
||
<span class="n">low_latency_gemm_swiglu_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span>
|
||
<span class="s2">"low_latency_gemm_swiglu"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span>
|
||
<span class="n">low_latency_gemm_swiglu_plug</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="cuda_stream_sync">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.cuda_stream_sync">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">cuda_stream_sync</span><span class="p">(</span><span class="n">input_list</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
|
||
<span class="n">side_stream_id</span><span class="p">:</span> <span class="n">SideStreamIDType</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Wait for the side stream on the main stream.</span>
|
||
<span class="sd"> output = input_list[0]</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input_list : List[Tensor] (On GPU)</span>
|
||
<span class="sd"> The list of input tensors.</span>
|
||
<span class="sd"> side_stream_id : int (On CPU)</span>
|
||
<span class="sd"> The side stream ID.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s2">"CudaStream"</span><span class="p">,</span> <span class="s2">"1"</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">p_side_stream_id</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"side_stream_id"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">side_stream_id</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">p_num_inputs</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"num_inputs"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">input_list</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">input_list</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">dtype</span><span class="p">)],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">p_side_stream_id</span><span class="p">,</span> <span class="n">p_num_inputs</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"cuda_stream"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="nb">input</span> <span class="ow">in</span> <span class="n">input_list</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">"cuda_stream"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">output</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="cp_split_plugin">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.cp_split_plugin">[docs]</a>
<span class="k">def</span> <span class="nf">cp_split_plugin</span><span class="p">(</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="c1"># for pad-free input mode</span>
<span class="n">cp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">cp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
<span class="sd"> Add an operation to perform splitting for context parallelism.</span>

<span class="sd"> This operation splits input_ids into cp_size chunks and returns the</span>
<span class="sd"> cp_rank-th chunk.</span>
<span class="sd"> When seqlen % cp_size != 0, the chunk sizes of the ranks are</span>
<span class="sd"> [seqlen // cp_size, seqlen // cp_size, ..., seqlen - (seqlen // cp_size) * (cp_size - 1)],</span>
<span class="sd"> i.e. the last rank also takes the remainder</span>
<span class="sd"> (e.g. seqlen=10, cp_size=4 gives chunk sizes [2, 2, 2, 4]).</span>

<span class="sd"> It inserts an IPluginV3Layer.</span>

<span class="sd"> Parameters:</span>
<span class="sd"> input_ids : Tensor</span>
<span class="sd"> The input tensor that contains the indices to split.</span>

<span class="sd"> host_request_types : Tensor (On CPU)</span>
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
<span class="sd"> in docs/gpt_attention.md.</span>

<span class="sd"> host_context_lengths : Tensor (On CPU)</span>
<span class="sd"> A host tensor that contains the lengths of the different inputs</span>
<span class="sd"> (used for pad-free input mode).</span>

<span class="sd"> cp_size : int = 1</span>
<span class="sd"> The context-parallel size, i.e. the number of chunks.</span>

<span class="sd"> cp_rank : int = 0</span>
<span class="sd"> The index of the chunk to return.</span>

<span class="sd"> Returns:</span>
<span class="sd"> A tuple of two tensors: the output split tensor (plugin output 0)</span>
<span class="sd"> and plugin output 2, presumably the index for rebuilding the</span>
<span class="sd"> sequence (TODO confirm). Plugin output 1, presumably the length of</span>
<span class="sd"> the output split tensor, is not returned.</span>
<span class="sd"> NOTE(review): the signature is annotated as returning a single</span>
<span class="sd"> Tensor, but a tuple is returned.</span>
<span class="sd"> '''</span>
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_creator</span><span class="p">(</span>
<span class="s1">'CpSplit'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>

<span class="n">cp_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"cp_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">cp_size</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">cp_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"cp_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">cp_rank</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>

<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">cp_size</span><span class="p">,</span> <span class="n">cp_rank</span><span class="p">])</span>
<span class="n">cp_split_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"cp_split"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">TensorRTPhase</span><span class="o">.</span><span class="n">BUILD</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">input_ids</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">host_request_types</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="o">.</span><span class="n">trt_tensor</span>
<span class="p">]</span>

<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v3</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="p">[],</span> <span class="n">cp_split_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">"cp_split"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">layer</span><span class="p">),</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
</pre></div>
|
||
|
||
</div>
|
||
</div>
|
||
<footer>
  <!-- Page footer: copyright notice and NVIDIA legal/privacy links (each link opens in a new tab). -->
  <hr>
  <div role="contentinfo">
    <div class="footer">
      <p>
        Copyright © 2024 NVIDIA Corporation
      </p>
      <p>
        <a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/" target="_blank" rel="noopener"
          data-cms-ai="0">Privacy Policy</a> |
        <a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/" target="_blank" rel="noopener"
          data-cms-ai="0">Manage My Privacy</a> |
        <a class="Link" href="https://www.nvidia.com/en-us/preferences/start/" target="_blank" rel="noopener"
          data-cms-ai="0">Do Not Sell or Share My Data</a> |
        <a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/" target="_blank"
          rel="noopener" data-cms-ai="0">Terms of Service</a> |
        <a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/accessibility/" target="_blank" rel="noopener"
          data-cms-ai="0">Accessibility</a> |
        <a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/company-policies/" target="_blank"
          rel="noopener" data-cms-ai="0">Corporate Policies</a> |
        <a class="Link" href="https://www.nvidia.com/en-us/product-security/" target="_blank" rel="noopener"
          data-cms-ai="0">Product Security</a> |
        <a class="Link" href="https://www.nvidia.com/en-us/contact/" target="_blank" rel="noopener"
          data-cms-ai="0">Contact</a>
      </p>
    </div>
  </div>
</footer>
|
||
</div>
|
||
</div>
|
||
</section>
|
||
</div>
|
||
<script>
  // Once the DOM is ready (jQuery(fn) shorthand), enable the theme's sidebar
  // navigation. The `true` argument presumably turns on sticky navigation —
  // confirm against _static/js/theme.js.
  jQuery(function () {
    SphinxRtdTheme.Navigation.enable(true);
  });
</script>
|
||
|
||
</body>
|
||
</html> |