<!--
TensorRT-LLMs/_modules/tensorrt_llm/layers/embedding.html
2024-12-25 13:44:02 +08:00

415 lines
38 KiB
HTML
-->

<!DOCTYPE html>
<html class="writer-html5" lang="en" data-content_root="../../../">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>tensorrt_llm.layers.embedding &mdash; tensorrt_llm documentation</title>
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=80d5e7a1" />
<link rel="stylesheet" type="text/css" href="../../../_static/css/theme.css?v=e59714d7" />
<link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css?v=76b2166b" />
<script src="../../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../../_static/doctools.js?v=888ff710"></script>
<script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../../../_static/copybutton.js?v=65e89d2a"></script>
<script src="../../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../../index.html" class="icon icon-home">
tensorrt_llm
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../overview.html">Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../quick-start-guide.html">Quick Start Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../key-features.html">Key Features</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../release-notes.html">Release Notes</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Installation</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../installation/linux.html">Installing on Linux</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../installation/build-from-source-linux.html">Building from Source Code on Linux</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../installation/windows.html">Installing on Windows</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../installation/build-from-source-windows.html">Building from Source Code on Windows</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../installation/grace-hopper.html">Installing on Grace Hopper</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">LLM API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../llm-api/index.html">API Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../llm-api/reference.html">API Reference</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">LLM API Examples</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../llm-api-examples/index.html">LLM Examples Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../llm-api-examples/customization.html">Common Customizations</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../llm-api-examples/llm_api_examples.html">Examples</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Model Definition API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.layers.html">Layers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.functional.html">Functionals</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.models.html">Models</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../_cpp_gen/executor.html">Executor</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../_cpp_gen/runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Command-Line Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../commands/trtllm-build.html">trtllm-build</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../commands/trtllm-serve.html">trtllm-serve</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Architecture</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/overview.html">TensorRT-LLM Architecture</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/core-concepts.html">Model Definition</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/core-concepts.html#compilation">Compilation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/core-concepts.html#runtime">Runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/core-concepts.html#multi-gpu-and-multi-node-support">Multi-GPU and Multi-Node Support</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/checkpoint.html">TensorRT-LLM Checkpoint</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/workflow.html">TensorRT-LLM Build Workflow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/add-model.html">Adding a Model</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Advanced</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/gpt-attention.html">Multi-Head, Multi-Query, and Group-Query Attention</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/gpt-runtime.html">C++ GPT Runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/executor.html">Executor API</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/graph-rewriting.html">Graph Rewriting Module</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/inference-request.html">Inference Request</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/inference-request.html#responses">Responses</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/lora.html">Run gpt-2b + LoRA using GptManager / cpp runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../performance/perf-overview.html">Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../performance/perf-benchmarking.html">Benchmarking</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../performance/perf-best-practices.html">Best Practices</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../performance/perf-analysis.html">Performance Analysis</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../reference/troubleshooting.html">Troubleshooting</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../reference/support-matrix.html">Support Matrix</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../reference/precision.html">Numerical Precision</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../reference/memory.html">Memory Usage of TensorRT-LLM</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Blogs</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/Falcon180B-H200.html">Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/quantization-in-TRT-LLM.html">Speed up inference with SOTA quantization techniques in TRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/XQA-kernel.html">New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../../index.html">tensorrt_llm</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../../index.html" class="icon icon-home" aria-label="Home"></a></li>
<li class="breadcrumb-item"><a href="../../index.html">Module code</a></li>
<li class="breadcrumb-item active">tensorrt_llm.layers.embedding</li>
<li class="wy-breadcrumbs-aside">
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<h1>Source code for tensorrt_llm.layers.embedding</h1><div class="highlight"><pre>
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="kn">import</span> <span class="nn">math</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">.._utils</span> <span class="kn">import</span> <span class="n">set_obj_attrs</span><span class="p">,</span> <span class="n">str_dtype_to_torch</span><span class="p">,</span> <span class="n">trt_dtype_to_np</span>
<span class="kn">from</span> <span class="nn">..functional</span> <span class="kn">import</span> <span class="n">constant</span><span class="p">,</span> <span class="n">embedding</span><span class="p">,</span> <span class="n">unsqueeze</span><span class="p">,</span> <span class="n">where</span>
<span class="kn">from</span> <span class="nn">..mapping</span> <span class="kn">import</span> <span class="n">Mapping</span>
<span class="kn">from</span> <span class="nn">..module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="kn">from</span> <span class="nn">..parameter</span> <span class="kn">import</span> <span class="n">Parameter</span>
<div class="viewcode-block" id="Embedding">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.embedding.Embedding">[docs]</a>
<span class="k">class</span> <span class="nc">Embedding</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> The embedding layer takes input indices (x) and the embedding lookup table (weight) as input.</span>
<span class="sd"> It outputs the corresponding embeddings according to the input indices.</span>
<span class="sd"> The size of weight is [num_embeddings, embedding_dim]</span>
<span class="sd"> Four parameters (tp_size, tp_group, sharding_dim, tp_rank) are involved in tensor parallelism.</span>
<span class="sd"> Only when &quot;tp_size &gt; 1 and tp_group is not None&quot;, tensor parallelism is enabled.</span>
<span class="sd"> When &quot;sharding_dim == 0&quot;, the weight is sharded in the vocabulary dimension.</span>
<span class="sd"> tp_rank must be set when sharding_dim == 0.</span>
<span class="sd"> When &quot;sharding_dim == 1&quot;, the weight is sharded in the hidden dimension.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">num_embeddings</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">embedding_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">tp_group</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">list</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">sharding_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">tp_rank</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="c1"># num_embeddings records the total vocab size whether or not TP is used</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_embeddings</span> <span class="o">=</span> <span class="n">num_embeddings</span>
<span class="bp">self</span><span class="o">.</span><span class="n">embedding_dim</span> <span class="o">=</span> <span class="n">embedding_dim</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">=</span> <span class="n">tp_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span> <span class="o">=</span> <span class="n">tp_group</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sharding_dim</span> <span class="o">=</span> <span class="n">sharding_dim</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tp_rank</span> <span class="o">=</span> <span class="n">tp_rank</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tp_dim</span> <span class="o">=</span> <span class="n">sharding_dim</span>
<span class="k">if</span> <span class="n">sharding_dim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_embeddings</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding_dim</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">sharding_dim</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_embeddings</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">embedding_dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight_padding_size</span> <span class="o">=</span> <span class="p">((</span><span class="mi">8</span> <span class="o">-</span> <span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">%</span> <span class="mi">8</span><span class="p">)</span> <span class="o">%</span> <span class="mi">8</span><span class="p">,</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="n">set_obj_attrs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="p">{</span>
<span class="s2">&quot;weight_loader&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_loader</span><span class="p">,</span>
<span class="p">})</span>
<div class="viewcode-block" id="Embedding.forward">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.embedding.Embedding.forward">[docs]</a>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="c1"># The embedding weight is padded to a multiple of 8.</span>
<span class="c1"># The reason is that when lm_head and vocab_embedding are using the same embedding weight,</span>
<span class="c1"># the weights previously could not be deduplicated in the engine because gemm pads the weight to a multiple of 8.</span>
<span class="c1"># If we also pad the embedding weight to a multiple of 8, the weights can be successfully deduplicated.</span>
<span class="c1"># This will not affect the input and output of the gather op and perf impact is negligible.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_padding_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">padding_values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weight_padding_size</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
<span class="n">padding</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">padding_values</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">padding</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">embedding</span><span class="p">(</span><span class="n">x</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">value</span><span class="p">,</span>
<span class="n">tp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
<span class="n">tp_group</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span><span class="p">,</span>
<span class="n">sharding_dim</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sharding_dim</span><span class="p">,</span>
<span class="n">tp_rank</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">tp_rank</span><span class="p">,</span>
<span class="n">padding</span><span class="o">=</span><span class="n">padding</span><span class="p">)</span></div>
<div class="viewcode-block" id="Embedding.weight_loader">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.embedding.Embedding.weight_loader">[docs]</a>
<span class="k">def</span> <span class="nf">weight_loader</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">Parameter</span><span class="p">,</span>
<span class="n">loaded_weight</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="c1"># use_parallel_embedding</span>
<span class="n">tp_rank</span> <span class="o">=</span> <span class="n">mapping</span><span class="o">.</span><span class="n">tp_rank</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">sharding_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sharding_dim</span>
<span class="n">shard_size</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">_shape</span><span class="p">[</span><span class="n">sharding_dim</span><span class="p">]</span>
<span class="n">start_idx</span> <span class="o">=</span> <span class="n">tp_rank</span> <span class="o">*</span> <span class="n">shard_size</span>
<span class="n">loaded_weight</span> <span class="o">=</span> <span class="n">loaded_weight</span><span class="o">.</span><span class="n">narrow</span><span class="p">(</span><span class="n">sharding_dim</span><span class="p">,</span> <span class="n">start_idx</span><span class="p">,</span>
<span class="n">shard_size</span><span class="p">)</span>
<span class="n">param</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">loaded_weight</span></div>
<div class="viewcode-block" id="Embedding.postprocess">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.embedding.Embedding.postprocess">[docs]</a>
<span class="k">def</span> <span class="nf">postprocess</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tllm_key</span><span class="p">,</span> <span class="n">weights</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="k">if</span> <span class="n">weights</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span> <span class="p">{}</span>
<span class="n">weights</span> <span class="o">=</span> <span class="n">weights</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">str_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
<span class="k">return</span> <span class="p">{</span><span class="n">tllm_key</span><span class="p">:</span> <span class="n">weights</span><span class="p">}</span></div>
</div>
<div class="viewcode-block" id="PromptTuningEmbedding">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.embedding.PromptTuningEmbedding">[docs]</a>
<span class="k">class</span> <span class="nc">PromptTuningEmbedding</span><span class="p">(</span><span class="n">Embedding</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> PromptTuningEmbedding handles fine-tuned prompts with virtual tokens. At runtime,</span>
<span class="sd"> a supplementary embedding dictionary is passed. Tokens whose ids are &gt;= vocab_size are embedded</span>
<span class="sd"> with that additional dictionary.</span>
<span class="sd"> The prompt tuning dictionary holds multiple tasks, and each sequence is assigned a given task.</span>
<span class="sd"> Prompt-tuned tokens from a given sequence use the adequate task dictionary, as defined by the `tasks` input.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">num_embeddings</span><span class="p">,</span>
<span class="n">embedding_dim</span><span class="p">,</span>
<span class="n">vocab_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">tp_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">tp_group</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">sharding_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">tp_rank</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_embeddings</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">tp_size</span><span class="p">,</span>
<span class="n">tp_group</span><span class="p">,</span> <span class="n">sharding_dim</span><span class="p">,</span> <span class="n">tp_rank</span><span class="p">)</span>
<span class="k">if</span> <span class="n">vocab_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">vocab_size</span> <span class="o">=</span> <span class="n">num_embeddings</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span> <span class="o">=</span> <span class="n">vocab_size</span>
<div class="viewcode-block" id="PromptTuningEmbedding.forward">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.layers.html#tensorrt_llm.layers.embedding.PromptTuningEmbedding.forward">[docs]</a>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tokens</span><span class="p">,</span> <span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">task_vocab_size</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Pass all tokens through both normal and prompt embedding tables.</span>
<span class="sd"> Tokens are masked so that the &quot;normal&quot; embedding only sees &quot;normal&quot; tokens. The same logic applies to the &quot;prompt&quot; embedding.</span>
<span class="sd"> After those two embeddings, combine the results based on whether the token was &quot;normal&quot; or &quot;prompt-tuned&quot;.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tokens : Tensor</span>
<span class="sd"> the ids to embed, size [batch_size, seq_len]</span>
<span class="sd"> prompt_embedding_table : Tensor</span>
<span class="sd"> the additional embedding table for prompt-tuned tokens, size [num_tasks * num_tokens_per_task, hidden_size]</span>
<span class="sd"> tasks: Tensor</span>
<span class="sd"> the task required by each token, size [batch_size, seq_len]</span>
<span class="sd"> task_vocab_size: Tensor</span>
<span class="sd"> the number of tokens used for each task, should be equal to prompt_embedding_table&#39;s num_tokens_per_task, size [1]</span>
<span class="sd"> Returns:</span>
<span class="sd"> Tokens&#39; embedding</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># do not use &quot;&gt;=&quot; because internally the layer works with floating points</span>
<span class="n">prompt_tokens_mask</span> <span class="o">=</span> <span class="n">tokens</span> <span class="o">&gt;</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="c1"># clip tokens in the [0, vocab_size) range</span>
<span class="n">normal_tokens</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">prompt_tokens_mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">tokens</span><span class="p">)</span>
<span class="n">normal_embeddings</span> <span class="o">=</span> <span class="n">embedding</span><span class="p">(</span><span class="n">normal_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">value</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sharding_dim</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_rank</span><span class="p">)</span>
<span class="c1"># put virtual tokens in the [0, max_prompt_vocab_size) range</span>
<span class="n">prompt_tokens</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">prompt_tokens_mask</span><span class="p">,</span> <span class="n">tokens</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="c1"># add offsets to match the concatenated embedding tables</span>
<span class="n">tasks</span> <span class="o">=</span> <span class="n">tasks</span> <span class="o">*</span> <span class="n">task_vocab_size</span>
<span class="c1"># tasks: [batch_size, seq_len]</span>
<span class="c1"># prompt_tokens: [batch_size, seq_len]</span>
<span class="n">prompt_tokens</span> <span class="o">=</span> <span class="n">prompt_tokens</span> <span class="o">+</span> <span class="n">tasks</span>
<span class="n">prompt_embeddings</span> <span class="o">=</span> <span class="n">embedding</span><span class="p">(</span><span class="n">prompt_tokens</span><span class="p">,</span> <span class="n">prompt_embedding_table</span><span class="p">)</span>
<span class="c1"># prompt_tokens_mask: [batch_size, seq_len] -&gt; [batch_size, seq_len, 1]</span>
<span class="c1"># combine the correct sources of embedding: normal/prompt</span>
<span class="k">return</span> <span class="n">where</span><span class="p">(</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">prompt_tokens_mask</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">prompt_embeddings</span><span class="p">,</span>
<span class="n">normal_embeddings</span><span class="p">)</span></div>
</div>
</pre></div>
</div>
</div>
<footer>
  <hr/>
  <!-- Page footer: copyright notice followed by NVIDIA legal and policy links.
       Every link opens in a new tab (target="_blank") with rel="noopener". -->
  <div role="contentinfo">
    <div class="footer">
      <p>
        Copyright © 2024 NVIDIA Corporation
      </p>
      <p>
        <a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/" target="_blank" rel="noopener"
          data-cms-ai="0">Privacy Policy</a> |
        <a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/" target="_blank" rel="noopener"
          data-cms-ai="0">Manage My Privacy</a> |
        <a class="Link" href="https://www.nvidia.com/en-us/preferences/start/" target="_blank" rel="noopener"
          data-cms-ai="0">Do Not Sell or Share My Data</a> |
        <a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/" target="_blank"
          rel="noopener" data-cms-ai="0">Terms of Service</a> |
        <a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/accessibility/" target="_blank" rel="noopener"
          data-cms-ai="0">Accessibility</a> |
        <a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/company-policies/" target="_blank"
          rel="noopener" data-cms-ai="0">Corporate Policies</a> |
        <a class="Link" href="https://www.nvidia.com/en-us/product-security/" target="_blank" rel="noopener"
          data-cms-ai="0">Product Security</a> |
        <a class="Link" href="https://www.nvidia.com/en-us/contact/" target="_blank" rel="noopener"
          data-cms-ai="0">Contact</a>
      </p>
    </div>
  </div>
</footer>
</div>
</div>
</section>
</div>
<script>
  // Passing a function to jQuery() registers it as a DOM-ready handler.
  jQuery(function () {
    // Turn on the theme's navigation behaviour once the document is ready.
    // NOTE(review): the `true` argument presumably selects sticky navigation;
    // confirm against the theme's theme.js.
    var themeNavigation = SphinxRtdTheme.Navigation;
    themeNavigation.enable(true);
  });
</script>
</body>
</html>