<!DOCTYPE html>
<html class="writer-html5" lang="en" data-content_root="../../../">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>tensorrt_llm.layers.linear &mdash; tensorrt_llm documentation</title>
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=80d5e7a1" />
<link rel="stylesheet" type="text/css" href="../../../_static/css/theme.css?v=19f00094" />
<!--[if lt IE 9]>
<script src="../../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script src="../../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../../_static/doctools.js?v=888ff710"></script>
<script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../../index.html" class="icon icon-home">
tensorrt_llm
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Contents:</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture.html">TensorRT-LLM Architecture</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../gpt_runtime.html">C++ GPT Runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../batch_manager.html">The Batch Manager in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../gpt_attention.html">Multi-head, Multi-query and Group-query Attention</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../precision.html">Numerical Precision</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../installation.html">Build TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../performance.html">Performance of TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../2023-05-19-how-to-debug.html">How to debug</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../2023-05-17-how-to-add-a-new-model.html">How to add a new model</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../graph-rewriting.html">Graph Rewriting Module</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../memory.html">Memory Usage of TensorRT-LLM</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.layers.html">Layers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.functional.html">Functionals</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.models.html">Models</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../_cpp_gen/runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Blogs</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../../index.html">tensorrt_llm</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../../index.html" class="icon icon-home" aria-label="Home"></a></li>
<li class="breadcrumb-item"><a href="../../index.html">Module code</a></li>
<li class="breadcrumb-item active">tensorrt_llm.layers.linear</li>
<li class="wy-breadcrumbs-aside">
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<h1>Source code for tensorrt_llm.layers.linear</h1><div class="highlight"><pre>
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">tensorrt</span> <span class="k">as</span> <span class="nn">trt</span>
<span class="kn">from</span> <span class="nn">.._common</span> <span class="kn">import</span> <span class="n">default_net</span><span class="p">,</span> <span class="n">default_trtnet</span>
<span class="kn">from</span> <span class="nn">.._utils</span> <span class="kn">import</span> <span class="n">int32_array</span><span class="p">,</span> <span class="n">str_dtype_to_trt</span>
<span class="kn">from</span> <span class="nn">..functional</span> <span class="kn">import</span> <span class="p">(</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">_create_tensor</span><span class="p">,</span> <span class="n">allgather</span><span class="p">,</span> <span class="n">allreduce</span><span class="p">,</span> <span class="n">cast</span><span class="p">,</span>
                          <span class="n">concat</span><span class="p">,</span> <span class="n">constant</span><span class="p">,</span> <span class="n">matmul</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="nb">slice</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">..module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="kn">from</span> <span class="nn">..parameter</span> <span class="kn">import</span> <span class="n">Parameter</span>
<span class="kn">from</span> <span class="nn">..plugin</span> <span class="kn">import</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span>
<span class="k">def</span> <span class="nf">_gemm_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
                 <span class="n">mat2</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
                 <span class="n">transa</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
                 <span class="n">transb</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
                 <span class="n">use_fp8</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
        <span class="s1">&#39;Gemm&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
    <span class="n">transa</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">transa</span> <span class="k">else</span> <span class="mi">0</span>
    <span class="n">transa</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;transa&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transa</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                             <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">transb</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">transb</span> <span class="k">else</span> <span class="mi">0</span>
    <span class="n">transb</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;transb&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transb</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                             <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">use_fp8</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">use_fp8</span> <span class="k">else</span> <span class="mi">0</span>
    <span class="n">use_fp8</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;use_fp8&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">use_fp8</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                              <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_plugin</span>
    <span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
        <span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
        <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">transa</span><span class="p">,</span> <span class="n">transb</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span> <span class="n">use_fp8</span><span class="p">])</span>
    <span class="n">gemm_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;gemm&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
    <span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">mat2</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">gemm_plug</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<div class="viewcode-block" id="Linear">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.Linear">[docs]</a>
<span class="k">class</span> <span class="nc">Linear</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                 <span class="n">in_features</span><span class="p">,</span>
                 <span class="n">out_features</span><span class="p">,</span>
                 <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
                 <span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">tp_group</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">tp_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
                 <span class="n">gather_output</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
                 <span class="n">share_weight</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">in_features</span> <span class="o">=</span> <span class="n">in_features</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">out_features</span> <span class="o">=</span> <span class="n">out_features</span> <span class="o">//</span> <span class="n">tp_size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
        <span class="k">if</span> <span class="ow">not</span> <span class="n">share_weight</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">out_features</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">in_features</span><span class="p">),</span>
                                    <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">share_weight</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">=</span> <span class="n">tp_size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span> <span class="o">=</span> <span class="n">tp_group</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">gather_output</span> <span class="o">=</span> <span class="n">gather_output</span>
        <span class="k">if</span> <span class="n">bias</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">out_features</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">register_parameter</span><span class="p">(</span><span class="s1">&#39;bias&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<div class="viewcode-block" id="Linear.multiply_gather">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.Linear.multiply_gather">[docs]</a>
    <span class="k">def</span> <span class="nf">multiply_gather</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">gemm_plugin</span><span class="p">,</span> <span class="n">use_fp8</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">gemm_plugin</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">_gemm_plugin</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">transb</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">use_fp8</span><span class="o">=</span><span class="n">use_fp8</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">matmul</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">transb</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="k">if</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">dtype</span><span class="p">:</span>
                <span class="n">x</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">value</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_output</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="c1"># 1. [dim0, local_dim] -&gt; [dim0 * tp_size, local_dim]</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">allgather</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span><span class="p">)</span>
            <span class="c1"># 2. [dim0 * tp_size, local_dim] -&gt; [dim0, local_dim * tp_size]</span>
            <span class="c1"># 2.1 split</span>
            <span class="n">split_size</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span>
            <span class="n">ndim</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
            <span class="n">starts</span> <span class="o">=</span> <span class="p">[</span><span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="mi">0</span><span class="p">]))</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
            <span class="n">sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">d</span><span class="p">)</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
            <span class="n">sizes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">split_size</span>
            <span class="n">sections</span> <span class="o">=</span> <span class="p">[]</span>
            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span><span class="p">):</span>
                <span class="n">starts</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">split_size</span> <span class="o">*</span> <span class="n">i</span>
                <span class="n">sections</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
            <span class="c1"># 2.2 concat</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">sections</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span></div>
<div class="viewcode-block" id="Linear.forward">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.Linear.forward">[docs]</a>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">multiply_gather</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">value</span><span class="p">,</span>
                                    <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_plugin</span><span class="p">)</span></div>
</div>
<span class="n">ColumnLinear</span> <span class="o">=</span> <span class="n">Linear</span>
<div class="viewcode-block" id="RowLinear">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.RowLinear">[docs]</a>
<span class="k">class</span> <span class="nc">RowLinear</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                 <span class="n">in_features</span><span class="p">,</span>
                 <span class="n">out_features</span><span class="p">,</span>
                 <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
                 <span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">tp_group</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">tp_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
                 <span class="n">instance_id</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">in_features</span> <span class="o">=</span> <span class="n">in_features</span> <span class="o">//</span> <span class="n">tp_size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">out_features</span> <span class="o">=</span> <span class="n">out_features</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">out_features</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">in_features</span><span class="p">),</span>
                                <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">bias</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">out_features</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">register_parameter</span><span class="p">(</span><span class="s1">&#39;bias&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span> <span class="o">=</span> <span class="n">tp_group</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">=</span> <span class="n">tp_size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">instance_id</span> <span class="o">=</span> <span class="n">instance_id</span>
<div class="viewcode-block" id="RowLinear.multiply_reduce">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.RowLinear.multiply_reduce">[docs]</a>
    <span class="k">def</span> <span class="nf">multiply_reduce</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                        <span class="n">x</span><span class="p">,</span>
                        <span class="n">weight</span><span class="p">,</span>
                        <span class="n">gemm_plugin</span><span class="p">,</span>
                        <span class="n">use_fp8</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
                        <span class="n">workspace</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">gemm_plugin</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">_gemm_plugin</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">transb</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">use_fp8</span><span class="o">=</span><span class="n">use_fp8</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">matmul</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">transb</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">allreduce</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span><span class="p">,</span> <span class="n">workspace</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">instance_id</span><span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="k">if</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">dtype</span><span class="p">:</span>
                <span class="n">x</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">value</span>
        <span class="k">return</span> <span class="n">x</span></div>
<div class="viewcode-block" id="RowLinear.forward">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.linear.RowLinear.forward">[docs]</a>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">workspace</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">multiply_reduce</span><span class="p">(</span><span class="n">x</span><span class="p">,</span>
                                    <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">value</span><span class="p">,</span>
                                    <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_plugin</span><span class="p">,</span>
                                    <span class="n">workspace</span><span class="o">=</span><span class="n">workspace</span><span class="p">)</span></div>
</div>
</pre></div>
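<p>The tensor-parallel data flow in the listing above can be hard to follow from the graph-building calls alone. The following is a minimal single-process NumPy sketch of it, not part of <code>tensorrt_llm</code>. The function names <code>column_parallel_linear</code> and <code>row_parallel_linear</code> are hypothetical, the ranks are simulated with a Python loop, <code>allgather</code> is assumed to stack the per-rank results along dimension 0 (as the comments in <code>multiply_gather</code> state), and the row-parallel input is assumed to arrive already sharded along its last dimension.</p>

```python
import numpy as np


def column_parallel_linear(x, weight, bias, tp_size):
    """Simulate Linear / ColumnLinear with gather_output=True (hypothetical helper)."""
    local_out = weight.shape[0] // tp_size  # mirrors out_features // tp_size
    partial = []
    for rank in range(tp_size):
        rows = slice(rank * local_out, (rank + 1) * local_out)
        # Each rank owns a block of output rows of the weight and its bias slice.
        partial.append(x @ weight[rows].T + bias[rows])  # [dim0, local_dim]
    # Step 1 (assumed allgather layout): [dim0 * tp_size, local_dim]
    gathered = np.concatenate(partial, axis=0)
    # Step 2.1: split dimension 0 back into tp_size sections.
    split_size = gathered.shape[0] // tp_size
    sections = [
        gathered[i * split_size:(i + 1) * split_size] for i in range(tp_size)
    ]
    # Step 2.2: concatenate along dimension 1: [dim0, local_dim * tp_size]
    return np.concatenate(sections, axis=1)


def row_parallel_linear(x, weight, bias, tp_size):
    """Simulate RowLinear: partial matmuls, sum across ranks, then bias."""
    local_in = weight.shape[1] // tp_size  # mirrors in_features // tp_size
    total = 0.0
    for rank in range(tp_size):
        cols = slice(rank * local_in, (rank + 1) * local_in)
        # Each rank multiplies its input shard by its weight column block.
        total = total + x[:, cols] @ weight[:, cols].T
    # The running sum stands in for allreduce; bias is added once afterwards.
    return total + bias


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((3, 8))
    w = rng.standard_normal((6, 8))
    b = rng.standard_normal(6)
    reference = x @ w.T + b
    print(np.allclose(column_parallel_linear(x, w, b, 2), reference))
    print(np.allclose(row_parallel_linear(x, w, b, 4), reference))
```

<p>Under these assumptions both functions should agree with the unsharded product <code>x @ w.T + b</code>. Adding the bias after the reduction in the row-parallel case, as <code>multiply_reduce</code> does, avoids counting it once per rank.</p>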
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>&#169; Copyright 2023, NVidia.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>