TensorRT-LLMs/_modules/tensorrt_llm/layers/linear.html
2024-09-30 19:28:28 +02:00

716 lines
83 KiB
HTML

<!DOCTYPE html>
<html class="writer-html5" lang="en" data-content_root="../../../">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>tensorrt_llm.layers.linear &mdash; tensorrt_llm documentation</title>
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=80d5e7a1" />
<link rel="stylesheet" type="text/css" href="../../../_static/css/theme.css?v=19f00094" />
<!--[if lt IE 9]>
<script src="../../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script src="../../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../../_static/doctools.js?v=888ff710"></script>
<script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../../index.html" class="icon icon-home">
tensorrt_llm
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../overview.html">Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../quick-start-guide.html">Quick Start Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../key-features.html">Key Features</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../release-notes.html">Release Notes</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Installation</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../installation/linux.html">Installing on Linux</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../installation/build-from-source-linux.html">Building from Source Code on Linux</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../installation/windows.html">Installing on Windows</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../installation/build-from-source-windows.html">Building from Source Code on Windows</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">LLM API Examples</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../llm-api-examples/index.html">LLM Examples Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../llm-api-examples/customization.html">Common Customizations</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../llm-api-examples/llm_api_examples.html">Examples</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">LLM API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../llm-api/index.html">API Reference</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Model Definition API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.layers.html">Layers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.functional.html">Functionals</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.models.html">Models</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../_cpp_gen/executor.html">Executor</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../_cpp_gen/runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Command-Line Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../commands/trtllm-build.html">trtllm-build</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Architecture</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/overview.html">TensorRT-LLM Architecture</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/core-concepts.html">Model Definition</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/core-concepts.html#compilation">Compilation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/core-concepts.html#runtime">Runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/core-concepts.html#multi-gpu-and-multi-node-support">Multi-GPU and Multi-Node Support</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/checkpoint.html">TensorRT-LLM Checkpoint</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/workflow.html">TensorRT-LLM Build Workflow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/add-model.html">Adding a Model</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Advanced</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/gpt-attention.html">Multi-Head, Multi-Query, and Group-Query Attention</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/gpt-runtime.html">C++ GPT Runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/graph-rewriting.html">Graph Rewriting Module</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/batch-manager.html">The Batch Manager in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/inference-request.html">Inference Request</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/inference-request.html#responses">Responses</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/lora.html">Run gpt-2b + LoRA using GptManager / cpp runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../performance/perf-overview.html">Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../performance/perf-best-practices.html">Best Practices for Tuning the Performance of TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../performance/perf-analysis.html">Performance Analysis</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../reference/troubleshooting.html">Troubleshooting</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../reference/support-matrix.html">Support Matrix</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../reference/precision.html">Numerical Precision</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../reference/memory.html">Memory Usage of TensorRT-LLM</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Blogs</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/Falcon180B-H200.html">Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/quantization-in-TRT-LLM.html">Speed up inference with SOTA quantization techniques in TRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/XQA-kernel.html">New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../../index.html">tensorrt_llm</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../../index.html" class="icon icon-home" aria-label="Home"></a></li>
<li class="breadcrumb-item"><a href="../../index.html">Module code</a></li>
<li class="breadcrumb-item active">tensorrt_llm.layers.linear</li>
<li class="wy-breadcrumbs-aside">
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<h1>Source code for tensorrt_llm.layers.linear</h1><div class="highlight"><pre>
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABCMeta</span><span class="p">,</span> <span class="n">abstractmethod</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">tensorrt</span> <span class="k">as</span> <span class="nn">trt</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">.._common</span> <span class="kn">import</span> <span class="n">default_net</span><span class="p">,</span> <span class="n">default_trtnet</span>
<span class="kn">from</span> <span class="nn">.._utils</span> <span class="kn">import</span> <span class="n">set_obj_attrs</span><span class="p">,</span> <span class="n">str_dtype_to_torch</span><span class="p">,</span> <span class="n">str_dtype_to_trt</span>
<span class="kn">from</span> <span class="nn">..functional</span> <span class="kn">import</span> <span class="p">(</span><span class="n">AllReduceFusionOp</span><span class="p">,</span> <span class="n">AllReduceFusionParams</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">_add_plugin_info</span><span class="p">,</span> <span class="n">_create_tensor</span><span class="p">,</span> <span class="n">allgather</span><span class="p">,</span>
<span class="n">allreduce</span><span class="p">,</span> <span class="n">cast</span><span class="p">,</span> <span class="n">low_latency_gemm</span><span class="p">,</span> <span class="n">matmul</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">..mapping</span> <span class="kn">import</span> <span class="n">Mapping</span>
<span class="kn">from</span> <span class="nn">..module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="kn">from</span> <span class="nn">..parameter</span> <span class="kn">import</span> <span class="n">Parameter</span>
<span class="kn">from</span> <span class="nn">..plugin</span> <span class="kn">import</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span>
<span class="kn">from</span> <span class="nn">.lora</span> <span class="kn">import</span> <span class="n">LoraRuntimeParams</span>
<span class="k">def</span> <span class="nf">_gemm_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
                 <span class="n">mat2</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
                 <span class="n">transa</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
                 <span class="n">transb</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
                 <span class="n">pad_lda</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
                 <span class="n">pad_ldb</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
                 <span class="n">use_fp8</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
                 <span class="n">alpha</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
                 <span class="n">strict_dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd">    Add a &quot;Gemm&quot; plugin layer taking `input` and `mat2` and return its output.</span>

<span class="sd">    output = op(mat2)op(input)</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor (On GPU)</span>
<span class="sd">            The input tensor.</span>

<span class="sd">        mat2 : Tensor (On GPU)</span>
<span class="sd">            The mat2 tensor.</span>

<span class="sd">        transa : bool</span>
<span class="sd">            Is the input tensor transposed? Set to &#39;True&#39; if you want the</span>
<span class="sd">            input tensor to be transposed, &#39;False&#39; otherwise.</span>

<span class="sd">        transb : bool</span>
<span class="sd">            Is the mat2 tensor transposed? Set to &#39;True&#39; if you want the</span>
<span class="sd">            mat2 tensor to be transposed, &#39;False&#39; otherwise.</span>

<span class="sd">        pad_lda: int</span>
<span class="sd">            Padding to the leading dimension of input tensor. It is used to</span>
<span class="sd">            support the strided GEMM that only uses the sub-tensor for</span>
<span class="sd">            computation. The GEMM plugin computation is</span>
<span class="sd">            [N, K] x [K, M+pad_lda] -&gt; [N, M] if transa,</span>
<span class="sd">            [N, K] x [K+pad_lda, M] -&gt; [N, M] if not transa.</span>

<span class="sd">        pad_ldb: int</span>
<span class="sd">            Padding to the leading dimension of mat2 tensor. It is used to</span>
<span class="sd">            support the strided GEMM that only uses the sub-tensor for</span>
<span class="sd">            computation. The GEMM plugin computation is</span>
<span class="sd">            [N, K+pad_ldb] x [K, M] -&gt; [N, M] if transb,</span>
<span class="sd">            [N+pad_ldb, K] x [K, M] -&gt; [N, M] if not transb.</span>

<span class="sd">        use_fp8: bool</span>
<span class="sd">            Whether to use the fp8 GEMM. When set, both input and mat2 must</span>
<span class="sd">            have dtype trt.fp8 and alpha must be provided.</span>

<span class="sd">        alpha: Optional[np.ndarray]</span>
<span class="sd">            Alpha for fp8 GEMM. Must be a single-element float32 ndarray</span>
<span class="sd">            when use_fp8 is set. If it is None, 1.0 is used.</span>

<span class="sd">        strict_dtype: trt.DataType</span>
<span class="sd">            Set the data type for the GEMM plugin. If it is None, the data</span>
<span class="sd">            type is the gemm_plugin type set in the plugin_config.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor created from output 0 of the plugin layer.</span>
<span class="sd">    &#39;&#39;&#39;</span>
    <span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
        <span class="s2">&quot;Gemm&quot;</span><span class="p">,</span> <span class="s2">&quot;1&quot;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>

    <span class="k">if</span> <span class="n">use_fp8</span><span class="p">:</span>
        <span class="k">assert</span> <span class="p">(</span>
            <span class="nb">isinstance</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="ow">and</span> <span class="n">alpha</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span>
            <span class="ow">and</span> <span class="n">alpha</span><span class="o">.</span><span class="n">size</span> <span class="o">==</span> <span class="mi">1</span>
        <span class="p">),</span> <span class="s2">&quot;`alpha` must be passed as a float32 ndarray if `use_fp8` is enabled for _gemm_plugin&quot;</span>
        <span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span>
        <span class="k">assert</span> <span class="n">mat2</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span>

    <span class="c1"># Each argument is rebound to the PluginField that carries it to the plugin.</span>
    <span class="n">transa</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">transa</span> <span class="k">else</span> <span class="mi">0</span>
    <span class="n">transa</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;transa&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transa</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                             <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">transb</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">transb</span> <span class="k">else</span> <span class="mi">0</span>
    <span class="n">transb</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;transb&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transb</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                             <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">pad_lda</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;pad_lda&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">pad_lda</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                              <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">pad_ldb</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;pad_ldb&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">pad_ldb</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                              <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">use_fp8</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">use_fp8</span> <span class="k">else</span> <span class="mi">0</span>
    <span class="n">use_fp8</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;use_fp8&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">use_fp8</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                              <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="c1"># Compare against None explicitly: testing the truthiness of an ndarray</span>
    <span class="c1"># would replace a legitimate alpha of 0.0 with 1.0 and raises ValueError</span>
    <span class="c1"># for arrays holding more than one element.</span>
    <span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span> <span class="k">if</span> <span class="n">alpha</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    <span class="n">alpha</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;alpha&quot;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">.</span><span class="n">flatten</span><span class="p">(),</span>
                            <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">strict_dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">strict_dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span>
        <span class="n">p_dtype</span> <span class="o">=</span> <span class="n">strict_dtype</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="c1"># Fall back to the dtype configured for the gemm plugin; fp8 is only</span>
        <span class="c1"># accepted when it is requested through strict_dtype.</span>
        <span class="n">p_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_plugin</span><span class="p">)</span>
        <span class="k">assert</span> <span class="n">p_dtype</span> <span class="o">!=</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span><span class="p">,</span> <span class="s2">&quot;need to use strict dtype in gemm plugin fp8&quot;</span>
    <span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                              <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>

    <span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span>
        <span class="p">[</span><span class="n">transa</span><span class="p">,</span> <span class="n">transb</span><span class="p">,</span> <span class="n">pad_lda</span><span class="p">,</span> <span class="n">pad_ldb</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span> <span class="n">use_fp8</span><span class="p">,</span> <span class="n">alpha</span><span class="p">])</span>
    <span class="n">gemm_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;gemm&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
    <span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">mat2</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">gemm_plug</span><span class="p">)</span>
    <span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">&quot;gemm&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<div class="viewcode-block" id="LinearBase">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.LinearBase">[docs]</a>
<span class="k">class</span> <span class="nc">LinearBase</span><span class="p">(</span><span class="n">Module</span><span class="p">,</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">ABCMeta</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">local_in_features</span><span class="p">,</span>
<span class="n">local_out_features</span><span class="p">,</span>
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">tp_group</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">tp_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">share_weight</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">strict_dtype</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">pad_lda</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">prefer_managed_weight</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">in_features</span> <span class="o">=</span> <span class="n">local_in_features</span>
<span class="bp">self</span><span class="o">.</span><span class="n">out_features</span> <span class="o">=</span> <span class="n">local_out_features</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pad_lda</span> <span class="o">=</span> <span class="n">pad_lda</span>
<span class="bp">self</span><span class="o">.</span><span class="n">prefer_managed_weight</span> <span class="o">=</span> <span class="n">prefer_managed_weight</span>
<span class="bp">self</span><span class="o">.</span><span class="n">share_weight</span> <span class="o">=</span> <span class="n">share_weight</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">share_weight</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span>
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">out_features</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">in_features</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">prefer_managed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">prefer_managed_weight</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">set_obj_attrs</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span>
<span class="p">{</span>
<span class="s2">&quot;weight_loader&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_loader</span><span class="p">,</span>
<span class="p">},</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">share_weight</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">=</span> <span class="n">tp_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span> <span class="o">=</span> <span class="n">tp_group</span>
<span class="bp">self</span><span class="o">.</span><span class="n">strict_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="n">strict_dtype</span> <span class="k">else</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">bias</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">out_features</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">register_parameter</span><span class="p">(</span><span class="s2">&quot;bias&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="c1"># see optimize_model&#39;s add_lora for LoRA initialization</span>
<span class="bp">self</span><span class="o">.</span><span class="n">lora</span> <span class="o">=</span> <span class="kc">None</span>
<div class="viewcode-block" id="LinearBase.weight_loader"><!-- Highlighted listing of LinearBase.weight_loader. The shown code reads tp_rank from mapping, takes shard_size from param._shape at index tp_split_dim(), narrows loaded_weight along that dimension starting at tp_rank * shard_size, and assigns the result to param.value. The [docs] anchor links back to the API reference entry. -->
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.LinearBase.weight_loader">[docs]</a>
<span class="k">def</span> <span class="nf">weight_loader</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">Parameter</span><span class="p">,</span>
<span class="n">loaded_weight</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">tp_rank</span> <span class="o">=</span> <span class="n">mapping</span><span class="o">.</span><span class="n">tp_rank</span>
<span class="n">shard_size</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">_shape</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">tp_split_dim</span><span class="p">()]</span>
<span class="n">start_idx</span> <span class="o">=</span> <span class="n">tp_rank</span> <span class="o">*</span> <span class="n">shard_size</span>
<span class="n">loaded_weight</span> <span class="o">=</span> <span class="n">loaded_weight</span><span class="o">.</span><span class="n">narrow</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tp_split_dim</span><span class="p">(),</span> <span class="n">start_idx</span><span class="p">,</span>
<span class="n">shard_size</span><span class="p">)</span>
<span class="n">param</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">loaded_weight</span></div>
<div class="viewcode-block" id="LinearBase.tp_split_dim"><!-- Highlighted listing of LinearBase.tp_split_dim. The shown code declares it with both the classmethod and abstractmethod decorators, annotates an int return, and leaves the body as pass. -->
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.LinearBase.tp_split_dim">[docs]</a>
<span class="nd">@classmethod</span>
<span class="nd">@abstractmethod</span>
<span class="k">def</span> <span class="nf">tp_split_dim</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="k">pass</span></div>
<div class="viewcode-block" id="LinearBase.weight_is_kn"><!-- Highlighted listing of LinearBase.weight_is_kn. The shown code returns the conjunction of plugin_config.manage_weights on default_net(), self.prefer_managed_weight, and self.weight.dtype being trt.DataType.HALF. The inline Python comment labels it a workaround (WAR) for bug 4641821. -->
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.LinearBase.weight_is_kn">[docs]</a>
<span class="k">def</span> <span class="nf">weight_is_kn</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="c1"># WAR for bug 4641821</span>
<span class="k">return</span> <span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">manage_weights</span>
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">prefer_managed_weight</span>
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span><span class="p">)</span></div>
<div class="viewcode-block" id="LinearBase.get_weight"><!-- Highlighted listing of LinearBase.get_weight. In the shown code, when plugin_config.manage_weights and self.prefer_managed_weight are both truthy, it returns self.weight.get_managed_tensor(...) and requests a transpose only if weight_is_kn() holds, gemm_plugin is None, and low_latency_gemm_plugin is not 'fp8'. Otherwise it returns self.weight.get_constant_tensor(...). -->
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.LinearBase.get_weight">[docs]</a>
<span class="k">def</span> <span class="nf">get_weight</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">(</span>
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">manage_weights</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">prefer_managed_weight</span><span class="p">:</span>
<span class="n">use_gemm_plugin</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">(</span>
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_plugin</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">use_low_latency_gemm_plugin</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">(</span>
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">low_latency_gemm_plugin</span> <span class="o">==</span> <span class="s1">&#39;fp8&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">get_managed_tensor</span><span class="p">(</span>
<span class="n">network</span><span class="o">=</span><span class="n">default_net</span><span class="p">(),</span>
<span class="n">need_transpose</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">weight_is_kn</span><span class="p">()</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">use_gemm_plugin</span>
<span class="ow">and</span> <span class="ow">not</span> <span class="n">use_low_latency_gemm_plugin</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">get_constant_tensor</span><span class="p">(</span><span class="n">network</span><span class="o">=</span><span class="n">default_net</span><span class="p">())</span></div>
<div class="viewcode-block" id="LinearBase.multiply_and_lora"><!-- Highlighted listing of LinearBase.multiply_and_lora. The shown code keeps the incoming x as hidden_state, then picks one multiply path: low_latency_gemm when low_latency_gemm_plugin is truthy, else _gemm_plugin (transb=True, pad_lda from self) when gemm_plugin is truthy, else matmul with transb set to not weight_is_kn(). For the low latency path and the 'fp8' gemm plugin, strict_dtype is self.dtype converted through str_dtype_to_trt when it is a str; for other gemm plugins it is self.strict_dtype. Finally, when plugin_config.lora_plugin is truthy and lora_runtime_params is not None, it adds self.lora(...) applied to lora_hidden_state if given, otherwise to hidden_state. -->
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.LinearBase.multiply_and_lora">[docs]</a>
<span class="k">def</span> <span class="nf">multiply_and_lora</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">x</span><span class="p">,</span>
<span class="n">weight</span><span class="p">,</span>
<span class="n">gemm_plugin</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">low_latency_gemm_plugin</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">use_fp8</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">alpha</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_runtime_params</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LoraRuntimeParams</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_hidden_state</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">hidden_state</span> <span class="o">=</span> <span class="n">x</span>
<span class="k">if</span> <span class="n">low_latency_gemm_plugin</span><span class="p">:</span>
<span class="n">strict_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">low_latency_gemm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">strict_dtype</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">gemm_plugin</span><span class="p">:</span>
<span class="k">if</span> <span class="n">gemm_plugin</span> <span class="o">==</span> <span class="s1">&#39;fp8&#39;</span><span class="p">:</span>
<span class="n">strict_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">strict_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">strict_dtype</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">_gemm_plugin</span><span class="p">(</span><span class="n">x</span><span class="p">,</span>
<span class="n">weight</span><span class="p">,</span>
<span class="n">transb</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">pad_lda</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">pad_lda</span><span class="p">,</span>
<span class="n">use_fp8</span><span class="o">=</span><span class="n">use_fp8</span><span class="p">,</span>
<span class="n">alpha</span><span class="o">=</span><span class="n">alpha</span><span class="p">,</span>
<span class="n">strict_dtype</span><span class="o">=</span><span class="n">strict_dtype</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">matmul</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">transb</span><span class="o">=</span><span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_is_kn</span><span class="p">())</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">(</span>
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lora_plugin</span> <span class="ow">and</span> <span class="n">lora_runtime_params</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora</span><span class="p">(</span>
<span class="n">hidden_state</span>
<span class="k">if</span> <span class="n">lora_hidden_state</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">lora_hidden_state</span><span class="p">,</span>
<span class="n">lora_runtime_params</span><span class="o">=</span><span class="n">lora_runtime_params</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">x</span></div>
<div class="viewcode-block" id="LinearBase.collect_and_bias"><!-- Highlighted listing of LinearBase.collect_and_bias. The shown code declares it with the abstractmethod decorator, taking a Tensor x and annotated to return a Tensor; the body is pass. -->
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.LinearBase.collect_and_bias">[docs]</a>
<span class="nd">@abstractmethod</span>
<span class="k">def</span> <span class="nf">collect_and_bias</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="k">pass</span></div>
<div class="viewcode-block" id="LinearBase.multiply_collect"><!-- Highlighted listing of LinearBase.multiply_collect. The shown code forwards x, weight and the named plugin, fp8, alpha and LoRA arguments unchanged to self.multiply_and_lora, then returns self.collect_and_bias applied to that result together with any extra keyword arguments. -->
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.LinearBase.multiply_collect">[docs]</a>
<span class="k">def</span> <span class="nf">multiply_collect</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">x</span><span class="p">,</span>
<span class="n">weight</span><span class="p">,</span>
<span class="n">gemm_plugin</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">low_latency_gemm_plugin</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">use_fp8</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">alpha</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_runtime_params</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LoraRuntimeParams</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_hidden_state</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">multiply_and_lora</span><span class="p">(</span>
<span class="n">x</span><span class="p">,</span>
<span class="n">weight</span><span class="p">,</span>
<span class="n">gemm_plugin</span><span class="o">=</span><span class="n">gemm_plugin</span><span class="p">,</span>
<span class="n">low_latency_gemm_plugin</span><span class="o">=</span><span class="n">low_latency_gemm_plugin</span><span class="p">,</span>
<span class="n">use_fp8</span><span class="o">=</span><span class="n">use_fp8</span><span class="p">,</span>
<span class="n">alpha</span><span class="o">=</span><span class="n">alpha</span><span class="p">,</span>
<span class="n">lora_runtime_params</span><span class="o">=</span><span class="n">lora_runtime_params</span><span class="p">,</span>
<span class="n">lora_hidden_state</span><span class="o">=</span><span class="n">lora_hidden_state</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">collect_and_bias</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<div class="viewcode-block" id="LinearBase.forward"><!-- Highlighted listing of LinearBase.forward. The shown code returns self.multiply_collect called with x, self.get_weight(), the gemm_plugin value read from default_net().plugin_config, use_fp8 fixed to False, the two LoRA arguments, and any extra keyword arguments. It does not pass low_latency_gemm_plugin, so that parameter keeps its default of None on this path. -->
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.LinearBase.forward">[docs]</a>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">x</span><span class="p">,</span>
<span class="n">lora_runtime_params</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LoraRuntimeParams</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_hidden_state</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">multiply_collect</span><span class="p">(</span>
<span class="n">x</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">get_weight</span><span class="p">(),</span>
<span class="n">gemm_plugin</span><span class="o">=</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_plugin</span><span class="p">,</span>
<span class="n">use_fp8</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">lora_runtime_params</span><span class="o">=</span><span class="n">lora_runtime_params</span><span class="p">,</span>
<span class="n">lora_hidden_state</span><span class="o">=</span><span class="n">lora_hidden_state</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
</div>
<div class="viewcode-block" id="Linear">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.Linear">[docs]</a>
<span class="k">class</span> <span class="nc">Linear</span><span class="p">(</span><span class="n">LinearBase</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><!-- Highlighted listing of Linear.__init__. The shown code passes in_features as local_in_features and out_features // tp_size as local_out_features to the base constructor, along with the other arguments unchanged; it then stores gather_output and is_qkv, sets tp_dim to 0, and, when bias is truthy, attaches self.weight_loader to self.bias through set_obj_attrs. -->
<span class="bp">self</span><span class="p">,</span>
<span class="n">in_features</span><span class="p">,</span>
<span class="n">out_features</span><span class="p">,</span>
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">tp_group</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">tp_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">gather_output</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">share_weight</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">strict_dtype</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">pad_lda</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">prefer_managed_weight</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">is_qkv</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
<span class="n">local_in_features</span><span class="o">=</span><span class="n">in_features</span><span class="p">,</span>
<span class="n">local_out_features</span><span class="o">=</span><span class="n">out_features</span> <span class="o">//</span> <span class="n">tp_size</span><span class="p">,</span>
<span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">tp_group</span><span class="o">=</span><span class="n">tp_group</span><span class="p">,</span>
<span class="n">tp_size</span><span class="o">=</span><span class="n">tp_size</span><span class="p">,</span>
<span class="n">share_weight</span><span class="o">=</span><span class="n">share_weight</span><span class="p">,</span>
<span class="n">strict_dtype</span><span class="o">=</span><span class="n">strict_dtype</span><span class="p">,</span>
<span class="n">pad_lda</span><span class="o">=</span><span class="n">pad_lda</span><span class="p">,</span>
<span class="n">prefer_managed_weight</span><span class="o">=</span><span class="n">prefer_managed_weight</span><span class="p">,</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gather_output</span> <span class="o">=</span> <span class="n">gather_output</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_qkv</span> <span class="o">=</span> <span class="n">is_qkv</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tp_dim</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">bias</span><span class="p">:</span>
<span class="n">set_obj_attrs</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span>
<span class="p">{</span>
<span class="s2">&quot;weight_loader&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_loader</span><span class="p">,</span>
<span class="p">},</span>
<span class="p">)</span>
<div class="viewcode-block" id="Linear.tp_split_dim"><!-- Highlighted listing of Linear.tp_split_dim. The shown code is a classmethod that returns the constant 0. -->
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.Linear.tp_split_dim">[docs]</a>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">tp_split_dim</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="k">return</span> <span class="mi">0</span></div>
<div class="viewcode-block" id="Linear.collect_and_bias"><!-- Highlighted listing of Linear.collect_and_bias. In the shown code, when self.bias is not None its value is cast to x.dtype and added to x. Then, when gather_output is truthy, tp_size is greater than 1 and tp_group is not None, x is replaced by allgather(x, self.tp_group, gather_dim=-1). The method returns x and accepts but does not use extra keyword arguments. -->
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.Linear.collect_and_bias">[docs]</a>
<span class="k">def</span> <span class="nf">collect_and_bias</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">value</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">bias</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_output</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># [dim0, local_dim] -&gt; [dim0 * tp_size, local_dim] --&gt; [dim0, local_dim * tp_size]</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">allgather</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span><span class="p">,</span> <span class="n">gather_dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">x</span></div>
<div class="viewcode-block" id="Linear.postprocess">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.Linear.postprocess">[docs]</a>
<span class="k">def</span> <span class="nf">postprocess</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tllm_key</span><span class="p">,</span> <span class="n">weights</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">using_head_as_leading_dim</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;using_head_as_leading_dim&quot;</span><span class="p">,</span>
<span class="kc">False</span><span class="p">)</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;config&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_qkv</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">weights</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="s2">&quot;remove_duplicated_kv_heads&quot;</span><span class="p">):</span>
<span class="k">if</span> <span class="n">config</span><span class="o">.</span><span class="n">remove_duplicated_kv_heads</span><span class="p">:</span>
<span class="n">head_size</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">//</span> <span class="n">config</span><span class="o">.</span><span class="n">num_attention_heads</span> <span class="k">if</span> <span class="n">config</span><span class="o">.</span><span class="n">head_size</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">config</span><span class="o">.</span><span class="n">head_size</span>
<span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">weights</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">reshape</span><span class="p">([</span>
<span class="n">k</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="n">head_size</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">in_features</span>
<span class="p">])</span>
<span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">reshape</span><span class="p">([</span>
<span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="n">head_size</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">in_features</span>
<span class="p">])</span>
<span class="k">assert</span> <span class="p">(</span><span class="n">k</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">k</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">all</span><span class="p">()</span>
<span class="k">assert</span> <span class="p">(</span><span class="n">v</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">v</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">all</span><span class="p">()</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">k</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">in_features</span><span class="p">])</span>
<span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">in_features</span><span class="p">])</span>
<span class="n">weights</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">k</span>
<span class="n">weights</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span>
<span class="n">weights</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">weights</span><span class="p">)</span>
<span class="k">if</span> <span class="n">using_head_as_leading_dim</span><span class="p">:</span>
<span class="c1"># Reorder [n_head, 3, head_dim, ...] into [3, n_head, head_dim, ...]</span>
<span class="k">assert</span> <span class="n">config</span><span class="o">.</span><span class="n">num_attention_heads</span> <span class="o">==</span> <span class="n">config</span><span class="o">.</span><span class="n">num_key_value_heads</span><span class="p">,</span> <span class="s2">&quot;using_head_as_leading_dim require head_size to be multiple of 3.&quot;</span>
<span class="n">num_heads</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">num_attention_heads</span>
<span class="n">head_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_features</span> <span class="o">//</span> <span class="p">(</span><span class="mi">3</span> <span class="o">*</span> <span class="n">num_heads</span><span class="p">)</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">weights</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">num_heads</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">head_dim</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">w</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">weights</span> <span class="o">=</span> <span class="n">w</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">in_features</span><span class="p">)</span> <span class="c1"># Weight</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">weights</span> <span class="o">=</span> <span class="n">w</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># Bias</span>
<span class="n">weights</span> <span class="o">=</span> <span class="n">weights</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">str_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
<span class="k">return</span> <span class="p">{</span><span class="n">tllm_key</span><span class="p">:</span> <span class="n">weights</span><span class="p">}</span></div>
</div>
<span class="n">ColumnLinear</span> <span class="o">=</span> <span class="n">Linear</span>
<div class="viewcode-block" id="RowLinear">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.RowLinear">[docs]</a>
<span class="k">class</span> <span class="nc">RowLinear</span><span class="p">(</span><span class="n">LinearBase</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">in_features</span><span class="p">,</span>
<span class="n">out_features</span><span class="p">,</span>
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">tp_group</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">tp_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">strict_dtype</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">pad_lda</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">prefer_managed_weight</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">is_expert</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
<span class="n">local_in_features</span><span class="o">=</span><span class="n">in_features</span> <span class="o">//</span> <span class="n">tp_size</span><span class="p">,</span>
<span class="n">local_out_features</span><span class="o">=</span><span class="n">out_features</span><span class="p">,</span>
<span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">tp_group</span><span class="o">=</span><span class="n">tp_group</span><span class="p">,</span>
<span class="n">tp_size</span><span class="o">=</span><span class="n">tp_size</span><span class="p">,</span>
<span class="n">strict_dtype</span><span class="o">=</span><span class="n">strict_dtype</span><span class="p">,</span>
<span class="n">pad_lda</span><span class="o">=</span><span class="n">pad_lda</span><span class="p">,</span>
<span class="n">prefer_managed_weight</span><span class="o">=</span><span class="n">prefer_managed_weight</span><span class="p">,</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tp_dim</span> <span class="o">=</span> <span class="mi">1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">=</span> <span class="n">tp_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_expert</span> <span class="o">=</span> <span class="n">is_expert</span>
<div class="viewcode-block" id="RowLinear.tp_split_dim">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.RowLinear.tp_split_dim">[docs]</a>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">tp_split_dim</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="k">return</span> <span class="mi">1</span></div>
<div class="viewcode-block" id="RowLinear.collect_and_bias">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.RowLinear.collect_and_bias">[docs]</a>
<span class="k">def</span> <span class="nf">collect_and_bias</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">reduce_fusion_params</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AllReduceFusionParams</span><span class="p">]</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
<span class="s2">&quot;reduce_fusion_params&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">need_bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">fuse_bias_into_all_reduce</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">need_bias</span> <span class="ow">and</span> <span class="p">(</span><span class="n">reduce_fusion_params</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span>
<span class="ow">and</span> <span class="p">(</span><span class="n">reduce_fusion_params</span><span class="o">.</span><span class="n">fusion_op</span>
<span class="o">==</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">RESIDUAL_RMS_NORM</span><span class="p">))</span>
<span class="k">if</span> <span class="n">fuse_bias_into_all_reduce</span><span class="p">:</span>
<span class="n">reduce_fusion_params</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">value</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_expert</span><span class="p">:</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">allreduce</span><span class="p">(</span><span class="n">x</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span><span class="p">,</span>
<span class="n">reduce_fusion_params</span><span class="o">=</span><span class="n">reduce_fusion_params</span><span class="p">)</span>
<span class="k">if</span> <span class="n">need_bias</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">fuse_bias_into_all_reduce</span><span class="p">:</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">value</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">bias</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">need_bias</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">fuse_bias_into_all_reduce</span><span class="p">:</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">value</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">bias</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span>
<span class="k">return</span> <span class="n">x</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">value</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">bias</span>
<span class="k">return</span> <span class="n">x</span></div>
</div>
</pre></div>
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<div class="footer">
<p>
Copyright © 2024 NVIDIA Corporation
</p>
<p>
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/" target="_blank" rel="noopener"
data-cms-ai="0">Privacy Policy</a> |
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/" target="_blank" rel="noopener"
data-cms-ai="0">Manage My Privacy</a> |
<a class="Link" href="https://www.nvidia.com/en-us/preferences/start/" target="_blank" rel="noopener"
data-cms-ai="0">Do Not Sell or Share My Data</a> |
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/" target="_blank"
rel="noopener" data-cms-ai="0">Terms of Service</a> |
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/accessibility/" target="_blank" rel="noopener"
data-cms-ai="0">Accessibility</a> |
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/company-policies/" target="_blank"
rel="noopener" data-cms-ai="0">Corporate Policies</a> |
<a class="Link" href="https://www.nvidia.com/en-us/product-security/" target="_blank" rel="noopener"
data-cms-ai="0">Product Security</a> |
<a class="Link" href="https://www.nvidia.com/en-us/contact/" target="_blank" rel="noopener"
data-cms-ai="0">Contact</a>
</p>
</div>
</div>
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>