mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
666 lines
79 KiB
HTML
666 lines
79 KiB
HTML
<!DOCTYPE html>
|
|
<html class="writer-html5" lang="en" data-content_root="../../../../">
|
|
<head>
  <!-- Document metadata. The charset declaration comes first. -->
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>tensorrt_llm.models.llama.model — tensorrt_llm documentation</title>

  <!-- Stylesheets: Pygments highlighting rules, then the theme. -->
  <link rel="stylesheet" href="../../../../_static/pygments.css?v=80d5e7a1">
  <link rel="stylesheet" href="../../../../_static/css/theme.css?v=19f00094">

  <!-- Conditional comment: only Internet Explorer older than 9 loads this shim. -->
  <!--[if lt IE 9]>
    <script src="../../../../_static/js/html5shiv.min.js"></script>
  <![endif]-->

  <!--
    Classic (blocking) scripts, kept in their original order.
    NOTE(review): do not reorder, defer or async these without checking
    what the rest of the page expects to be loaded at parse time.
  -->
  <script src="../../../../_static/jquery.js?v=5d32c60e"></script>
  <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
  <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
  <script src="../../../../_static/doctools.js?v=888ff710"></script>
  <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
  <script src="../../../../_static/js/theme.js"></script>

  <!-- Relations to the site-wide index and search pages. -->
  <link rel="index" title="Index" href="../../../../genindex.html">
  <link rel="search" title="Search" href="../../../../search.html">
</head>
|
|
|
|
<body class="wy-body-for-nav">
|
|
<div class="wy-grid-for-nav">
|
|
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
  <div class="wy-side-scroll">
    <!-- Sidebar header: link to the documentation home page and the search form. -->
    <div class="wy-side-nav-search">
      <a href="../../../../index.html" class="icon icon-home">
        tensorrt_llm
      </a>
      <div role="search">
        <!-- GET form submitting the query as "q" to search.html. -->
        <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
          <input type="text" name="q" placeholder="Search docs" aria-label="Search docs">
          <input type="hidden" name="check_keywords" value="yes">
          <input type="hidden" name="area" value="default">
        </form>
      </div>
    </div>
    <!-- Table of contents: one captioned list per documentation section. -->
    <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
      <p class="caption" role="heading"><span class="caption-text">Contents:</span></p>
      <ul>
        <li class="toctree-l1"><a class="reference internal" href="../../../../architecture.html">TensorRT-LLM Architecture</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../gpt_runtime.html">C++ GPT Runtime</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../batch_manager.html">The Batch Manager in TensorRT-LLM</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../gpt_attention.html">Multi-head, Multi-query and Group-query Attention</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../precision.html">Numerical Precision</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../installation.html">TensorRT-LLM Installation</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../performance.html">Performance of TensorRT-LLM</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../2023-05-19-how-to-debug.html">How to debug</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../2023-05-17-how-to-add-a-new-model.html">How to add a new model</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../graph-rewriting.html">Graph Rewriting Module</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../memory.html">Memory Usage of TensorRT-LLM</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../new_workflow.html">New Workflow</a></li>
      </ul>
      <p class="caption" role="heading"><span class="caption-text">Python API</span></p>
      <ul>
        <li class="toctree-l1"><a class="reference internal" href="../../../../python-api/tensorrt_llm.layers.html">Layers</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../python-api/tensorrt_llm.functional.html">Functionals</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../python-api/tensorrt_llm.models.html">Models</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
      </ul>
      <p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
      <ul>
        <li class="toctree-l1"><a class="reference internal" href="../../../../_cpp_gen/runtime.html">Runtime</a></li>
      </ul>
      <p class="caption" role="heading"><span class="caption-text">Blogs</span></p>
      <ul>
        <li class="toctree-l1"><a class="reference internal" href="../../../../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../blogs/Falcon180B-H200.html">Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100</a></li>
      </ul>
    </div>
  </div>
</nav>
|
|
|
|
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
  <!-- Top bar labelled as the mobile navigation menu: menu icon plus a link to the home page. -->
  <nav class="wy-nav-top" aria-label="Mobile navigation menu">
    <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
    <a href="../../../../index.html">tensorrt_llm</a>
  </nav>
|
|
|
|
<div class="wy-nav-content">
|
|
<div class="rst-content">
|
|
<div role="navigation" aria-label="Page navigation">
  <!-- Breadcrumb trail: home, module-code index, then the current page (not a link). -->
  <ul class="wy-breadcrumbs">
    <li><a href="../../../../index.html" class="icon icon-home" aria-label="Home"></a></li>
    <li class="breadcrumb-item"><a href="../../../index.html">Module code</a></li>
    <li class="breadcrumb-item active">tensorrt_llm.models.llama.model</li>
    <!-- Intentionally empty item. -->
    <li class="wy-breadcrumbs-aside"></li>
  </ul>
  <hr>
</div>
|
|
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
|
<div itemprop="articleBody">
|
|
|
|
<h1>Source code for tensorrt_llm.models.llama.model</h1><div class="highlight"><pre>
|
|
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
|
|
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
|
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
|
<span class="c1"># You may obtain a copy of the License at</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
|
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
|
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
|
<span class="c1"># See the License for the specific language governing permissions and</span>
|
|
<span class="c1"># limitations under the License.</span>
|
|
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span>
|
|
|
|
<span class="kn">import</span> <span class="nn">tensorrt</span> <span class="k">as</span> <span class="nn">trt</span>
|
|
|
|
<span class="kn">from</span> <span class="nn">..._common</span> <span class="kn">import</span> <span class="n">default_net</span>
|
|
<span class="kn">from</span> <span class="nn">..._utils</span> <span class="kn">import</span> <span class="n">pad_vocab_size</span><span class="p">,</span> <span class="n">str_dtype_to_trt</span>
|
|
<span class="kn">from</span> <span class="nn">...functional</span> <span class="kn">import</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">gather_last_token_logits</span><span class="p">,</span> <span class="n">recv</span><span class="p">,</span> <span class="n">send</span>
|
|
<span class="kn">from</span> <span class="nn">...layers</span> <span class="kn">import</span> <span class="p">(</span><span class="n">MOE</span><span class="p">,</span> <span class="n">Attention</span><span class="p">,</span> <span class="n">AttentionMaskType</span><span class="p">,</span> <span class="n">AttentionParams</span><span class="p">,</span>
|
|
<span class="n">ColumnLinear</span><span class="p">,</span> <span class="n">Embedding</span><span class="p">,</span> <span class="n">FusedGatedMLP</span><span class="p">,</span> <span class="n">GatedMLP</span><span class="p">,</span>
|
|
<span class="n">KeyValueCacheParams</span><span class="p">,</span> <span class="n">LoraParams</span><span class="p">,</span> <span class="n">MoeConfig</span><span class="p">,</span>
|
|
<span class="n">PositionEmbeddingType</span><span class="p">,</span> <span class="n">PromptTuningEmbedding</span><span class="p">,</span> <span class="n">RmsNorm</span><span class="p">)</span>
|
|
<span class="kn">from</span> <span class="nn">...mapping</span> <span class="kn">import</span> <span class="n">Mapping</span>
|
|
<span class="kn">from</span> <span class="nn">...module</span> <span class="kn">import</span> <span class="n">Module</span><span class="p">,</span> <span class="n">ModuleList</span>
|
|
<span class="kn">from</span> <span class="nn">...quantization</span> <span class="kn">import</span> <span class="n">QuantMode</span>
|
|
<span class="kn">from</span> <span class="nn">..generation_mixin</span> <span class="kn">import</span> <span class="n">GenerationMixin</span>
|
|
|
|
|
|
<span class="k">class</span> <span class="nc">LLaMADecoderLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
|
|
|
|
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">layer_id</span><span class="p">,</span>
|
|
<span class="n">hidden_size</span><span class="p">,</span>
|
|
<span class="n">num_attention_heads</span><span class="p">,</span>
|
|
<span class="n">num_kv_heads</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">max_position_embeddings</span><span class="o">=</span><span class="mi">2048</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">attention_mask_type</span><span class="o">=</span><span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">causal</span><span class="p">,</span>
|
|
<span class="n">hidden_act</span><span class="o">=</span><span class="s1">'silu'</span><span class="p">,</span>
|
|
<span class="n">position_embedding_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">,</span>
|
|
<span class="n">rotary_base</span><span class="o">=</span><span class="mf">10000.0</span><span class="p">,</span>
|
|
<span class="n">rotary_scaling</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">mlp_hidden_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">tp_group</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">tp_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
|
<span class="n">tp_rank</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
|
<span class="n">quant_mode</span><span class="o">=</span><span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
|
<span class="n">rms_norm_eps</span><span class="o">=</span><span class="mf">1e-06</span><span class="p">,</span>
|
|
<span class="n">attn_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">mlp_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">use_fused_mlp</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">moe_config</span><span class="p">:</span> <span class="n">MoeConfig</span> <span class="o">=</span> <span class="n">MoeConfig</span><span class="p">()):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_layer_id</span> <span class="o">=</span> <span class="n">layer_id</span> <span class="c1"># useful for debugging</span>
|
|
<span class="c1"># used for quantizing model</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">hidden_size</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_attention_heads</span> <span class="o">=</span> <span class="n">num_attention_heads</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_kv_heads</span> <span class="o">=</span> <span class="n">num_kv_heads</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_position_embeddings</span> <span class="o">=</span> <span class="n">max_position_embeddings</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">hidden_act</span> <span class="o">=</span> <span class="n">hidden_act</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">tp_group</span> <span class="o">=</span> <span class="n">tp_group</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">=</span> <span class="n">tp_size</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mlp_hidden_size</span> <span class="o">=</span> <span class="n">mlp_hidden_size</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">attention_mask_type</span> <span class="o">=</span> <span class="n">attention_mask_type</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">position_embedding_type</span> <span class="o">=</span> <span class="n">position_embedding_type</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">input_layernorm</span> <span class="o">=</span> <span class="n">RmsNorm</span><span class="p">(</span><span class="n">normalized_shape</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
|
|
<span class="n">eps</span><span class="o">=</span><span class="n">rms_norm_eps</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">attention</span> <span class="o">=</span> <span class="n">Attention</span><span class="p">(</span>
|
|
<span class="n">hidden_size</span><span class="p">,</span>
|
|
<span class="n">num_attention_heads</span><span class="p">,</span>
|
|
<span class="n">num_kv_heads</span><span class="p">,</span>
|
|
<span class="n">max_position_embeddings</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
|
|
<span class="n">attention_mask_type</span><span class="o">=</span><span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">causal</span><span class="p">,</span>
|
|
<span class="n">bias</span><span class="o">=</span><span class="n">attn_bias</span><span class="p">,</span>
|
|
<span class="n">position_embedding_type</span><span class="o">=</span><span class="n">position_embedding_type</span><span class="p">,</span>
|
|
<span class="n">rotary_embedding_base</span><span class="o">=</span><span class="n">rotary_base</span><span class="p">,</span>
|
|
<span class="n">rotary_embedding_scaling</span><span class="o">=</span><span class="n">rotary_scaling</span><span class="p">,</span>
|
|
<span class="n">tp_group</span><span class="o">=</span><span class="n">tp_group</span><span class="p">,</span>
|
|
<span class="n">tp_size</span><span class="o">=</span><span class="n">tp_size</span><span class="p">,</span>
|
|
<span class="n">quant_mode</span><span class="o">=</span><span class="n">quant_mode</span><span class="p">,</span>
|
|
<span class="n">instance_id</span><span class="o">=</span><span class="mi">2</span> <span class="o">*</span> <span class="n">layer_id</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="n">mlp_hidden_size</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mlp_hidden_size</span> <span class="o">=</span> <span class="n">hidden_size</span> <span class="o">*</span> <span class="mi">4</span>
|
|
|
|
<span class="n">ClsMLP</span> <span class="o">=</span> <span class="n">GatedMLP</span>
|
|
<span class="n">mlp_kwargs</span> <span class="o">=</span> <span class="p">{}</span>
|
|
<span class="k">if</span> <span class="n">moe_config</span><span class="o">.</span><span class="n">has_moe</span><span class="p">():</span>
|
|
<span class="n">ClsMLP</span> <span class="o">=</span> <span class="n">MOE</span>
|
|
<span class="n">mlp_kwargs</span> <span class="o">=</span> <span class="p">{</span>
|
|
<span class="s2">"moe_config"</span><span class="p">:</span> <span class="n">moe_config</span><span class="p">,</span>
|
|
<span class="s2">"tp_rank"</span><span class="p">:</span> <span class="n">tp_rank</span><span class="p">,</span>
|
|
<span class="p">}</span>
|
|
<span class="k">elif</span> <span class="n">use_fused_mlp</span><span class="p">:</span>
|
|
<span class="n">ClsMLP</span> <span class="o">=</span> <span class="n">FusedGatedMLP</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mlp</span> <span class="o">=</span> <span class="n">ClsMLP</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
|
|
<span class="n">ffn_hidden_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">mlp_hidden_size</span><span class="p">,</span>
|
|
<span class="n">hidden_act</span><span class="o">=</span><span class="n">hidden_act</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
|
|
<span class="n">bias</span><span class="o">=</span><span class="n">mlp_bias</span><span class="p">,</span>
|
|
<span class="n">tp_group</span><span class="o">=</span><span class="n">tp_group</span><span class="p">,</span>
|
|
<span class="n">tp_size</span><span class="o">=</span><span class="n">tp_size</span><span class="p">,</span>
|
|
<span class="n">quant_mode</span><span class="o">=</span><span class="n">quant_mode</span><span class="p">,</span>
|
|
<span class="n">instance_id</span><span class="o">=</span><span class="mi">2</span> <span class="o">*</span> <span class="n">layer_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="o">**</span><span class="n">mlp_kwargs</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">post_layernorm</span> <span class="o">=</span> <span class="n">RmsNorm</span><span class="p">(</span><span class="n">normalized_shape</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
|
|
<span class="n">eps</span><span class="o">=</span><span class="n">rms_norm_eps</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">hidden_states</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">use_cache</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">kv_cache_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">attention_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">all_reduce_workspace</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">lora_layer_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
|
<span class="n">residual</span> <span class="o">=</span> <span class="n">hidden_states</span>
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_layernorm</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layer_id</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">register_network_output</span><span class="p">(</span><span class="sa">f</span><span class="s2">"norm0"</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">)</span>
|
|
|
|
<span class="n">attention_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span>
|
|
<span class="n">use_cache</span><span class="o">=</span><span class="n">use_cache</span><span class="p">,</span>
|
|
<span class="n">kv_cache_params</span><span class="o">=</span><span class="n">kv_cache_params</span><span class="p">,</span>
|
|
<span class="n">attention_params</span><span class="o">=</span><span class="n">attention_params</span><span class="p">,</span>
|
|
<span class="n">workspace</span><span class="o">=</span><span class="n">all_reduce_workspace</span><span class="p">,</span>
|
|
<span class="n">lora_layer_params</span><span class="o">=</span><span class="n">lora_layer_params</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
|
|
<span class="n">attention_output</span><span class="p">,</span> <span class="n">presents</span> <span class="o">=</span> <span class="n">attention_output</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layer_id</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">register_network_output</span><span class="p">(</span><span class="sa">f</span><span class="s2">"attn"</span><span class="p">,</span> <span class="n">attention_output</span><span class="p">)</span>
|
|
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">residual</span> <span class="o">+</span> <span class="n">attention_output</span>
|
|
|
|
<span class="n">residual</span> <span class="o">=</span> <span class="n">hidden_states</span>
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">post_layernorm</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layer_id</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">register_network_output</span><span class="p">(</span><span class="sa">f</span><span class="s2">"norm1"</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">)</span>
|
|
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span>
|
|
<span class="n">all_reduce_workspace</span><span class="p">,</span>
|
|
<span class="n">lora_layer_params</span><span class="o">=</span><span class="n">lora_layer_params</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layer_id</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">register_network_output</span><span class="p">(</span><span class="sa">f</span><span class="s2">"mlp"</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">)</span>
|
|
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">residual</span> <span class="o">+</span> <span class="n">hidden_states</span>
|
|
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">presents</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">hidden_states</span>
|
|
|
|
|
|
<div class="viewcode-block" id="LLaMAModel">
|
|
<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.LLaMAModel">[docs]</a>
|
|
<span class="k">class</span> <span class="nc">LLaMAModel</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
|
|
|
|
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">num_layers</span><span class="p">,</span>
|
|
<span class="n">num_heads</span><span class="p">,</span>
|
|
<span class="n">num_kv_heads</span><span class="p">,</span>
|
|
<span class="n">hidden_size</span><span class="p">,</span>
|
|
<span class="n">vocab_size</span><span class="p">,</span>
|
|
<span class="n">hidden_act</span><span class="p">,</span>
|
|
<span class="n">max_position_embeddings</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="p">,</span>
|
|
<span class="n">mlp_hidden_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">position_embedding_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">,</span>
|
|
<span class="n">rotary_base</span><span class="o">=</span><span class="mf">10000.0</span><span class="p">,</span>
|
|
<span class="n">rotary_scaling</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">mapping</span><span class="o">=</span><span class="n">Mapping</span><span class="p">(),</span>
|
|
<span class="n">quant_mode</span><span class="o">=</span><span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
|
<span class="n">use_parallel_embedding</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">embedding_sharding_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
|
<span class="n">rms_norm_eps</span><span class="o">=</span><span class="mf">1e-06</span><span class="p">,</span>
|
|
<span class="n">use_fused_mlp</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">attn_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">mlp_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">moe_config</span><span class="p">:</span> <span class="n">MoeConfig</span> <span class="o">=</span> <span class="n">MoeConfig</span><span class="p">(),</span>
|
|
<span class="n">use_prompt_tuning</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span> <span class="o">=</span> <span class="n">mapping</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">use_prompt_tuning</span> <span class="o">=</span> <span class="n">use_prompt_tuning</span>
|
|
|
|
<span class="n">EmbeddingCls</span> <span class="o">=</span> <span class="n">PromptTuningEmbedding</span> <span class="k">if</span> <span class="n">use_prompt_tuning</span> <span class="k">else</span> <span class="n">Embedding</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_embedding</span> <span class="o">=</span> <span class="n">EmbeddingCls</span><span class="p">(</span>
|
|
<span class="n">num_embeddings</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span>
|
|
<span class="n">embedding_dim</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
|
|
<span class="n">tp_size</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="k">if</span> <span class="n">use_parallel_embedding</span> <span class="k">else</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="n">tp_group</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_group</span> <span class="k">if</span> <span class="n">use_parallel_embedding</span> <span class="k">else</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">sharding_dim</span><span class="o">=</span><span class="n">embedding_sharding_dim</span><span class="p">,</span>
|
|
<span class="n">tp_rank</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_rank</span><span class="p">,</span>
|
|
<span class="n">instance_id</span><span class="o">=</span><span class="mi">2</span> <span class="o">*</span>
|
|
<span class="n">num_layers</span><span class="p">,</span> <span class="c1"># ids in [0, 2 * (num_layers - 1) + 1] already used</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">ModuleList</span><span class="p">([</span>
|
|
<span class="n">LLaMADecoderLayer</span><span class="p">(</span>
|
|
<span class="n">layer_id</span><span class="o">=</span><span class="n">i</span><span class="p">,</span>
|
|
<span class="n">hidden_size</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
|
|
<span class="n">num_attention_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span>
|
|
<span class="n">num_kv_heads</span><span class="o">=</span><span class="n">num_kv_heads</span><span class="p">,</span>
|
|
<span class="n">max_position_embeddings</span><span class="o">=</span><span class="n">max_position_embeddings</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
|
|
<span class="n">hidden_act</span><span class="o">=</span><span class="n">hidden_act</span><span class="p">,</span>
|
|
<span class="n">mlp_hidden_size</span><span class="o">=</span><span class="n">mlp_hidden_size</span><span class="p">,</span>
|
|
<span class="n">position_embedding_type</span><span class="o">=</span><span class="n">position_embedding_type</span><span class="p">,</span>
|
|
<span class="n">rotary_base</span><span class="o">=</span><span class="n">rotary_base</span><span class="p">,</span>
|
|
<span class="n">rotary_scaling</span><span class="o">=</span><span class="n">rotary_scaling</span><span class="p">,</span>
|
|
<span class="n">tp_group</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_group</span><span class="p">,</span>
|
|
<span class="n">tp_size</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
|
|
<span class="n">tp_rank</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_rank</span><span class="p">,</span>
|
|
<span class="n">quant_mode</span><span class="o">=</span><span class="n">quant_mode</span><span class="p">,</span>
|
|
<span class="n">rms_norm_eps</span><span class="o">=</span><span class="n">rms_norm_eps</span><span class="p">,</span>
|
|
<span class="n">attn_bias</span><span class="o">=</span><span class="n">attn_bias</span><span class="p">,</span>
|
|
<span class="n">mlp_bias</span><span class="o">=</span><span class="n">mlp_bias</span><span class="p">,</span>
|
|
<span class="n">use_fused_mlp</span><span class="o">=</span><span class="n">use_fused_mlp</span><span class="p">,</span>
|
|
<span class="n">moe_config</span><span class="o">=</span><span class="n">moe_config</span><span class="p">,</span>
|
|
<span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_layers</span><span class="p">(</span><span class="n">num_layers</span><span class="p">)</span>
|
|
<span class="p">])</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">ln_f</span> <span class="o">=</span> <span class="n">RmsNorm</span><span class="p">(</span><span class="n">normalized_shape</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
|
|
<span class="n">eps</span><span class="o">=</span><span class="n">rms_norm_eps</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
|
|
|
|
<div class="viewcode-block" id="LLaMAModel.forward">
|
|
<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.LLaMAModel.forward">[docs]</a>
|
|
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="p">,</span>
|
|
<span class="n">position_ids</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">use_cache</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">kv_cache_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">attention_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">hidden_states</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">all_reduce_workspace</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">prompt_tasks</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">lora_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
|
|
|
<span class="n">kv_cache_params</span><span class="o">.</span><span class="n">fill_none_tensor_list</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">))</span>
|
|
|
|
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
|
|
<span class="n">presents</span> <span class="o">=</span> <span class="p">[]</span>
|
|
|
|
<span class="n">ptuning_args</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_prompt_tuning</span><span class="p">:</span>
|
|
<span class="n">ptuning_args</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">prompt_tasks</span><span class="p">,</span> <span class="n">prompt_vocab_size</span>
|
|
<span class="p">]</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_embedding</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="o">*</span><span class="n">ptuning_args</span><span class="p">,</span>
|
|
<span class="n">all_reduce_workspace</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">recv</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">prev_pp_rank</span><span class="p">())</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">register_network_output</span><span class="p">(</span><span class="sa">f</span><span class="s2">"embd"</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">)</span>
|
|
|
|
<span class="k">for</span> <span class="n">layer_idx</span><span class="p">,</span> <span class="p">(</span>
|
|
<span class="n">layer</span><span class="p">,</span> <span class="n">past</span><span class="p">,</span> <span class="n">pointer</span><span class="p">,</span> <span class="n">host_pointer</span><span class="p">,</span>
|
|
<span class="n">max_attention_window_size</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span>
|
|
<span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">,</span> <span class="n">kv_cache_params</span><span class="o">.</span><span class="n">past_key_value</span><span class="p">,</span>
|
|
<span class="n">kv_cache_params</span><span class="o">.</span><span class="n">kv_cache_block_pointers</span><span class="p">,</span>
|
|
<span class="n">kv_cache_params</span><span class="o">.</span><span class="n">host_kv_cache_block_pointers</span><span class="p">,</span>
|
|
<span class="n">kv_cache_params</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span><span class="p">)):</span>
|
|
<span class="n">lora_layer_params</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">if</span> <span class="n">lora_params</span><span class="o">.</span><span class="n">lora_ranks</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">lora_layer_params</span> <span class="o">=</span> <span class="n">lora_params</span><span class="o">.</span><span class="n">get_layer_params</span><span class="p">(</span><span class="n">layer_idx</span><span class="p">)</span>
|
|
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span>
|
|
<span class="n">hidden_states</span><span class="p">,</span>
|
|
<span class="n">use_cache</span><span class="o">=</span><span class="n">use_cache</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span>
|
|
<span class="n">kv_cache_params</span><span class="o">=</span><span class="n">KeyValueCacheParams</span><span class="p">(</span>
|
|
<span class="n">past_key_value</span><span class="o">=</span><span class="p">[</span><span class="n">past</span><span class="p">],</span>
|
|
<span class="n">host_past_key_value_lengths</span><span class="o">=</span><span class="n">kv_cache_params</span><span class="o">.</span>
|
|
<span class="n">host_past_key_value_lengths</span><span class="p">,</span>
|
|
<span class="n">host_max_attention_window_sizes</span><span class="o">=</span><span class="n">max_attention_window_size</span><span class="p">,</span>
|
|
<span class="n">kv_cache_block_pointers</span><span class="o">=</span><span class="p">[</span><span class="n">pointer</span><span class="p">],</span>
|
|
<span class="n">host_kv_cache_block_pointers</span><span class="o">=</span><span class="p">[</span><span class="n">host_pointer</span><span class="p">],</span>
|
|
<span class="n">cache_indirection</span><span class="o">=</span><span class="n">kv_cache_params</span><span class="o">.</span><span class="n">cache_indirection</span><span class="p">),</span>
|
|
<span class="n">attention_params</span><span class="o">=</span><span class="n">attention_params</span><span class="p">,</span>
|
|
<span class="n">all_reduce_workspace</span><span class="o">=</span><span class="n">all_reduce_workspace</span><span class="p">,</span>
|
|
<span class="n">lora_layer_params</span><span class="o">=</span><span class="n">lora_layer_params</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
|
|
<span class="n">presents</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ln_f</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">send</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">next_pp_rank</span><span class="p">())</span>
|
|
|
|
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">presents</span><span class="p">))</span>
|
|
<span class="k">return</span> <span class="n">hidden_states</span></div>
|
|
</div>
|
|
|
|
|
|
|
|
<div class="viewcode-block" id="LLaMAForCausalLM">
|
|
<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.LLaMAForCausalLM">[docs]</a>
|
|
<span class="k">class</span> <span class="nc">LLaMAForCausalLM</span><span class="p">(</span><span class="n">LLaMAModel</span><span class="p">,</span> <span class="n">GenerationMixin</span><span class="p">):</span>
|
|
|
|
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">num_layers</span><span class="p">,</span>
|
|
<span class="n">num_heads</span><span class="p">,</span>
|
|
<span class="n">num_kv_heads</span><span class="p">,</span>
|
|
<span class="n">hidden_size</span><span class="p">,</span>
|
|
<span class="n">vocab_size</span><span class="p">,</span>
|
|
<span class="n">hidden_act</span><span class="p">,</span>
|
|
<span class="n">max_position_embeddings</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="p">,</span>
|
|
<span class="n">logits_dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">,</span>
|
|
<span class="n">mlp_hidden_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">position_embedding_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">,</span>
|
|
<span class="n">rotary_base</span><span class="o">=</span><span class="mf">10000.0</span><span class="p">,</span>
|
|
<span class="n">rotary_scaling</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">mapping</span><span class="o">=</span><span class="n">Mapping</span><span class="p">(),</span>
|
|
<span class="n">quant_mode</span><span class="o">=</span><span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
|
<span class="n">use_parallel_embedding</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">embedding_sharding_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
|
<span class="n">rms_norm_eps</span><span class="o">=</span><span class="mf">1e-06</span><span class="p">,</span>
|
|
<span class="n">use_fused_mlp</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">attn_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">mlp_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">moe_config</span><span class="o">=</span><span class="n">MoeConfig</span><span class="p">(),</span>
|
|
<span class="n">use_prompt_tuning</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
|
|
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
|
|
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">logits_dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">logits_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">logits_dtype</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">logits_dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">logits_dtype</span> <span class="o">=</span> <span class="n">logits_dtype</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">=</span> <span class="n">num_layers</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span>
|
|
<span class="k">if</span> <span class="n">num_kv_heads</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">num_kv_heads</span> <span class="o"><=</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="n">num_kv_heads</span> <span class="o">=</span> <span class="n">num_heads</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_kv_heads</span> <span class="o">=</span> <span class="n">num_kv_heads</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">hidden_size</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span> <span class="o">=</span> <span class="n">vocab_size</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">=</span> <span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">kv_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span>
|
|
<span class="k">if</span> <span class="n">quant_mode</span><span class="o">.</span><span class="n">has_int8_kv_cache</span><span class="p">():</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">kv_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="s1">'int8'</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_kv_cache</span><span class="p">():</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">kv_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="s1">'fp8'</span><span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">quant_mode</span> <span class="o">=</span> <span class="n">quant_mode</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">use_parallel_embedding</span> <span class="o">=</span> <span class="n">use_parallel_embedding</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">embedding_sharding_dim</span> <span class="o">=</span> <span class="n">embedding_sharding_dim</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">moe_config</span> <span class="o">=</span> <span class="n">moe_config</span>
|
|
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_layers</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">num_kv_heads</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span>
|
|
<span class="n">vocab_size</span><span class="p">,</span> <span class="n">hidden_act</span><span class="p">,</span> <span class="n">max_position_embeddings</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span>
|
|
<span class="n">mlp_hidden_size</span><span class="p">,</span> <span class="n">position_embedding_type</span><span class="p">,</span> <span class="n">rotary_base</span><span class="p">,</span>
|
|
<span class="n">rotary_scaling</span><span class="p">,</span> <span class="n">mapping</span><span class="p">,</span> <span class="n">quant_mode</span><span class="p">,</span>
|
|
<span class="n">use_parallel_embedding</span><span class="p">,</span> <span class="n">embedding_sharding_dim</span><span class="p">,</span>
|
|
<span class="n">rms_norm_eps</span><span class="p">,</span> <span class="n">use_fused_mlp</span><span class="p">,</span> <span class="n">attn_bias</span><span class="p">,</span> <span class="n">mlp_bias</span><span class="p">,</span>
|
|
<span class="n">moe_config</span><span class="p">,</span> <span class="n">use_prompt_tuning</span><span class="p">)</span>
|
|
|
|
<span class="n">vocab_size_padded</span> <span class="o">=</span> <span class="n">pad_vocab_size</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">lm_head</span> <span class="o">=</span> <span class="n">ColumnLinear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span>
|
|
<span class="n">vocab_size_padded</span><span class="p">,</span>
|
|
<span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
|
|
<span class="n">tp_group</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_group</span><span class="p">,</span>
|
|
<span class="n">tp_size</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
|
|
<span class="n">gather_output</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
|
|
<div class="viewcode-block" id="LLaMAForCausalLM.forward">
|
|
<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.LLaMAForCausalLM.forward">[docs]</a>
|
|
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="p">,</span>
|
|
<span class="n">position_ids</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">use_cache</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">last_token_ids</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">kv_cache_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">attention_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">hidden_states</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">all_reduce_workspace</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">prompt_tasks</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">lora_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">position_ids</span><span class="p">,</span> <span class="n">use_cache</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="p">,</span> <span class="n">kv_cache_params</span><span class="p">,</span>
|
|
<span class="n">attention_params</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span>
|
|
<span class="n">all_reduce_workspace</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">prompt_tasks</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">lora_params</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
|
|
<span class="n">hidden_states</span><span class="p">,</span> <span class="n">presents</span> <span class="o">=</span> <span class="n">hidden_states</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">gather_last_token_logits</span><span class="p">(</span>
|
|
<span class="n">hidden_states</span><span class="p">,</span> <span class="n">last_token_ids</span><span class="p">,</span>
|
|
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">)</span>
|
|
|
|
<span class="c1"># [batch_size, hidden_size] -> [batch_size, vocab_size]</span>
|
|
<span class="n">lm_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lm_head</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
|
|
<span class="n">lm_logits</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="s1">'logits'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">logits_dtype</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">hidden_states</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="s1">'hidden_states_output'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">use_cache</span> <span class="ow">and</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="o">==</span> <span class="kc">False</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">present</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_layers</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">),</span>
|
|
<span class="n">presents</span><span class="p">):</span>
|
|
<span class="n">present</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_dtype</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="k">return</span> <span class="p">(</span><span class="n">lm_logits</span><span class="p">,</span> <span class="n">presents</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">presents</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="k">return</span> <span class="n">lm_logits</span>
|
|
<span class="k">return</span> <span class="n">hidden_states</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="LLaMAForCausalLM.prepare_inputs"><!-- Highlighted source listing for LLaMAForCausalLM.prepare_inputs. This block lives inside a <pre>, so every space and newline below is rendered verbatim. -->
<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.LLaMAForCausalLM.prepare_inputs">[docs]</a>
    <span class="k">def</span> <span class="nf">prepare_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                       <span class="n">max_batch_size</span><span class="p">,</span>
                       <span class="n">max_input_len</span><span class="p">,</span>
                       <span class="n">max_new_tokens</span><span class="p">,</span>
                       <span class="n">use_cache</span><span class="p">,</span>
                       <span class="n">max_beam_width</span><span class="p">,</span>
                       <span class="n">max_num_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
                       <span class="n">prompt_embedding_table_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
                       <span class="n">gather_all_token_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
                       <span class="n">lora_target_modules</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="w">        </span><span class="sd">'''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the</span>
<span class="sd">            ranges of the dimensions when using TRT dynamic shapes.</span>

<span class="sd">            @return: a tuple of values which can be fed into self.forward()</span>
<span class="sd">        '''</span>

        <span class="c1"># Prepare inputs</span>
        <span class="n">head_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span>
        <span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
        <span class="n">use_gpt_attention_plugin</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">(</span>
        <span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span>
        <span class="n">use_gemm_plugin</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_plugin</span>
        <span class="n">paged_kv_cache</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_kv_cache</span>
        <span class="n">tokens_per_block</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">tokens_per_block</span>
        <span class="n">use_custom_all_reduce</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">(</span>
        <span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_custom_all_reduce</span>
        <span class="n">use_lora_plugin</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lora_plugin</span>

        <span class="n">model_inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_basic_inputs</span><span class="p">(</span>
            <span class="n">max_batch_size</span><span class="p">,</span>
            <span class="n">max_beam_width</span><span class="p">,</span>
            <span class="n">max_input_len</span><span class="p">,</span>
            <span class="n">max_new_tokens</span><span class="p">,</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">num_kv_heads</span><span class="p">,</span>
            <span class="n">head_size</span><span class="p">,</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">kv_dtype</span><span class="p">,</span>
            <span class="n">remove_input_padding</span><span class="o">=</span><span class="n">remove_input_padding</span><span class="p">,</span>
            <span class="n">use_gpt_attention_plugin</span><span class="o">=</span><span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
            <span class="n">use_gemm_plugin</span><span class="o">=</span><span class="n">use_gemm_plugin</span><span class="p">,</span>
            <span class="n">use_custom_all_reduce</span><span class="o">=</span><span class="n">use_custom_all_reduce</span><span class="p">,</span>
            <span class="n">paged_kv_cache</span><span class="o">=</span><span class="n">paged_kv_cache</span><span class="p">,</span>
            <span class="n">tokens_per_block</span><span class="o">=</span><span class="n">tokens_per_block</span><span class="p">,</span>
            <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
            <span class="n">num_heads</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span>
            <span class="n">mapping</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="p">,</span>
            <span class="n">max_num_tokens</span><span class="o">=</span><span class="n">max_num_tokens</span><span class="p">,</span>
            <span class="n">prompt_embedding_table_size</span><span class="o">=</span><span class="n">prompt_embedding_table_size</span><span class="p">,</span>
            <span class="n">gather_all_token_logits</span><span class="o">=</span><span class="n">gather_all_token_logits</span><span class="p">,</span>
            <span class="n">use_lora_plugin</span><span class="o">=</span><span class="n">use_lora_plugin</span><span class="p">,</span>
            <span class="n">lora_target_modules</span><span class="o">=</span><span class="n">lora_target_modules</span><span class="p">)</span>

        <span class="k">return</span> <span class="p">(</span>
            <span class="n">model_inputs</span><span class="p">[</span><span class="s1">'input_ids'</span><span class="p">],</span>
            <span class="n">model_inputs</span><span class="p">[</span><span class="s1">'position_ids'</span><span class="p">],</span>
            <span class="kc">True</span><span class="p">,</span>
            <span class="n">model_inputs</span><span class="p">[</span><span class="s1">'last_token_ids'</span><span class="p">],</span>
            <span class="n">model_inputs</span><span class="p">[</span><span class="s1">'attention_mask'</span><span class="p">],</span>
            <span class="n">KeyValueCacheParams</span><span class="p">(</span>
                <span class="n">past_key_value</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">'past_key_value'</span><span class="p">],</span>
                <span class="n">host_past_key_value_lengths</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span>
                    <span class="s1">'host_past_key_value_lengths'</span><span class="p">],</span>
                <span class="n">host_max_attention_window_sizes</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span>
                    <span class="s1">'host_max_attention_window_sizes'</span><span class="p">],</span>
                <span class="n">kv_cache_block_pointers</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span>
                    <span class="s1">'kv_cache_block_pointers_list'</span><span class="p">],</span>
                <span class="n">host_kv_cache_block_pointers</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span>
                    <span class="s1">'host_kv_cache_block_pointers_list'</span><span class="p">],</span>
                <span class="n">cache_indirection</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">'cache_indirection'</span><span class="p">],</span>
            <span class="p">),</span>
            <span class="n">AttentionParams</span><span class="p">(</span>
                <span class="n">sequence_length</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">'sequence_length'</span><span class="p">],</span>
                <span class="n">context_lengths</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">'context_lengths'</span><span class="p">],</span>
                <span class="n">host_context_lengths</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">'host_context_lengths'</span><span class="p">],</span>
                <span class="n">max_context_length</span><span class="o">=</span><span class="n">max_input_len</span><span class="p">,</span>
                <span class="n">host_request_types</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">'host_request_types'</span><span class="p">]),</span>
            <span class="n">model_inputs</span><span class="p">[</span><span class="s1">'hidden_states_input'</span><span class="p">],</span>
            <span class="n">model_inputs</span><span class="p">[</span><span class="s1">'all_reduce_workspace'</span><span class="p">],</span>
            <span class="n">model_inputs</span><span class="p">[</span><span class="s1">'prompt_embedding_table'</span><span class="p">],</span>
            <span class="n">model_inputs</span><span class="p">[</span><span class="s1">'tasks'</span><span class="p">],</span>
            <span class="n">model_inputs</span><span class="p">[</span><span class="s1">'prompt_vocab_size'</span><span class="p">],</span>
            <span class="n">LoraParams</span><span class="p">(</span>
                <span class="n">model_inputs</span><span class="p">[</span><span class="s1">'lora_ranks'</span><span class="p">],</span>
                <span class="n">model_inputs</span><span class="p">[</span><span class="s1">'lora_weights_pointers'</span><span class="p">],</span>
                <span class="n">host_context_lengths</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">'host_context_lengths'</span><span class="p">],</span>
                <span class="n">max_context_length</span><span class="o">=</span><span class="n">max_input_len</span><span class="p">,</span>
                <span class="n">host_request_types</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">'host_request_types'</span><span class="p">]),</span>
        <span class="p">)</span></div>
|
|
</div>
|
|
|
|
</pre></div>
|
|
|
|
</div>
|
|
</div>
|
|
<footer>
  <!-- Page footer: copyright notice followed by the build/theme attribution. -->
  <hr>

  <div role="contentinfo">
    <p>© Copyright 2023, NVIDIA.</p>
  </div>

  Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
|
|
</div>
|
|
</div>
|
|
</section>
|
|
</div>
|
|
<script>
  // jQuery(fn) defers fn until the DOM is ready; at that point hand control to
  // the theme's navigation module. The script body must contain only valid
  // JavaScript: any stray text line here aborts the whole block with a
  // syntax error and the call below never runs.
  jQuery(function () {
    SphinxRtdTheme.Navigation.enable(true);
  });
</script>
|
|
|
|
</body>
|
|
</html> |