<!--
TensorRT-LLMs/_modules/tensorrt_llm/models/gptneox/model.html
2023-10-25 19:55:42 +08:00

520 lines
60 KiB
HTML
-->

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <!-- Encoding and viewport come first so the parser and mobile browsers pick them up early. -->
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>tensorrt_llm.models.gptneox.model &mdash; tensorrt_llm documentation</title>
  <!-- Syntax-highlighting styles, then the theme stylesheet. -->
  <link rel="stylesheet" href="../../../../_static/pygments.css">
  <link rel="stylesheet" href="../../../../_static/css/theme.css">
  <!--[if lt IE 9]>
  <script src="../../../../_static/js/html5shiv.min.js"></script>
  <![endif]-->
  <!-- Keep this order: jQuery loads first, presumably because the compat shim, Sphinx helpers and theme script expect it. -->
  <script src="../../../../_static/jquery.js?v=5d32c60e"></script>
  <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
  <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js?v=b3ba4146"></script>
  <script src="../../../../_static/doctools.js?v=888ff710"></script>
  <script src="../../../../_static/sphinx_highlight.js?v=4825356b"></script>
  <script src="../../../../_static/js/theme.js"></script>
  <link rel="index" title="Index" href="../../../../genindex.html">
  <link rel="search" title="Search" href="../../../../search.html">
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
  <!-- Sidebar: project home link and search box, followed by the table of contents. -->
  <div class="wy-side-scroll">
    <div class="wy-side-nav-search">
      <a href="../../../../index.html" class="icon icon-home">
        tensorrt_llm
      </a>
      <div role="search">
        <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
          <input type="text" name="q" placeholder="Search docs" aria-label="Search docs">
          <input type="hidden" name="check_keywords" value="yes">
          <input type="hidden" name="area" value="default">
        </form>
      </div>
    </div>
    <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
      <p class="caption" role="heading"><span class="caption-text">Contents:</span></p>
      <ul>
        <li class="toctree-l1"><a class="reference internal" href="../../../../architecture.html">TensorRT-LLM Architecture</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../gpt_runtime.html">C++ GPT Runtime</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../batch_manager.html">The Batch Manager in TensorRT-LLM</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../gpt_attention.html">Multi-head, Multi-query and Group-query Attention</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../precision.html">Numerical Precision</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../performance.html">Performance of TensorRT-LLM</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../installation.html">Build From Sources</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../2023-05-19-how-to-debug.html">How to debug</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../2023-05-17-how-to-add-a-new-model.html">How to add a new model</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../graph-rewriting.html">Graph Rewriting Module</a></li>
      </ul>
      <p class="caption" role="heading"><span class="caption-text">Python API</span></p>
      <ul>
        <li class="toctree-l1"><a class="reference internal" href="../../../../python-api/tensorrt_llm.layers.html">Layers</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../python-api/tensorrt_llm.functional.html">Functionals</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../python-api/tensorrt_llm.models.html">Models</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
        <li class="toctree-l1"><a class="reference internal" href="../../../../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
      </ul>
      <p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
      <ul>
        <li class="toctree-l1"><a class="reference internal" href="../../../../_cpp_gen/runtime.html">Runtime</a></li>
      </ul>
    </div>
  </div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../../../index.html">tensorrt_llm</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
  <!-- Breadcrumb trail. The final crumb is the page being viewed; aria-current="page" exposes that to assistive technology. -->
  <ul class="wy-breadcrumbs">
    <li><a href="../../../../index.html" class="icon icon-home" aria-label="Home"></a></li>
    <li class="breadcrumb-item"><a href="../../../index.html">Module code</a></li>
    <li class="breadcrumb-item active" aria-current="page">tensorrt_llm.models.gptneox.model</li>
    <li class="wy-breadcrumbs-aside">
    </li>
  </ul>
  <hr>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<h1>Source code for tensorrt_llm.models.gptneox.model</h1><div class="highlight"><pre>
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="kn">import</span> <span class="nn">tensorrt</span> <span class="k">as</span> <span class="nn">trt</span>
<span class="kn">from</span> <span class="nn">..._common</span> <span class="kn">import</span> <span class="n">default_net</span>
<span class="kn">from</span> <span class="nn">..._utils</span> <span class="kn">import</span> <span class="n">pad_vocab_size</span><span class="p">,</span> <span class="n">str_dtype_to_trt</span>
<span class="kn">from</span> <span class="nn">...functional</span> <span class="kn">import</span> <span class="p">(</span><span class="n">PositionEmbeddingType</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">gather_last_token_logits</span><span class="p">,</span> <span class="n">gpt_attention</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">...layers</span> <span class="kn">import</span> <span class="p">(</span><span class="n">MLP</span><span class="p">,</span> <span class="n">AttentionMaskType</span><span class="p">,</span> <span class="n">AttentionParams</span><span class="p">,</span> <span class="n">ColumnLinear</span><span class="p">,</span>
<span class="n">Embedding</span><span class="p">,</span> <span class="n">KeyValueCacheParams</span><span class="p">,</span> <span class="n">LayerNorm</span><span class="p">,</span> <span class="n">RowLinear</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">...mapping</span> <span class="kn">import</span> <span class="n">Mapping</span>
<span class="kn">from</span> <span class="nn">...module</span> <span class="kn">import</span> <span class="n">Module</span><span class="p">,</span> <span class="n">ModuleList</span>
<span class="kn">from</span> <span class="nn">...parameter</span> <span class="kn">import</span> <span class="n">Parameter</span>
<span class="kn">from</span> <span class="nn">...quantization</span> <span class="kn">import</span> <span class="n">QuantMode</span>
<span class="kn">from</span> <span class="nn">..generation_mixin</span> <span class="kn">import</span> <span class="n">GenerationMixin</span>
<span class="k">class</span> <span class="nc">GPTNeoXAttention</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span><!-- Listing of GPTNeoXAttention. Leading whitespace inside the pre element is rendered verbatim, so each line is indented to mirror the Python block structure. -->
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                 <span class="n">hidden_size</span><span class="p">,</span>
                 <span class="n">num_attention_heads</span><span class="p">,</span>
                 <span class="n">rotary_dim</span><span class="p">,</span>
                 <span class="n">max_position_embeddings</span><span class="p">,</span>
                 <span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">multi_block_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
                 <span class="n">position_embedding_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">,</span>
                 <span class="n">quant_mode</span><span class="o">=</span><span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
                 <span class="n">tp_group</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">tp_size</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">attention_head_size</span> <span class="o">=</span> <span class="n">hidden_size</span> <span class="o">//</span> <span class="n">num_attention_heads</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_attention_heads</span> <span class="o">=</span> <span class="n">num_attention_heads</span> <span class="o">//</span> <span class="n">tp_size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">max_position_embeddings</span> <span class="o">=</span> <span class="n">max_position_embeddings</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">rotary_dim</span> <span class="o">=</span> <span class="n">rotary_dim</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">position_embedding_type</span> <span class="o">=</span> <span class="n">position_embedding_type</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">multi_block_mode</span> <span class="o">=</span> <span class="n">multi_block_mode</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">multi_query_mode</span> <span class="o">=</span> <span class="kc">False</span><!-- NOTE(review): multi_query_mode is assigned but not read anywhere in this class. -->
        <span class="bp">self</span><span class="o">.</span><span class="n">quant_mode</span> <span class="o">=</span> <span class="n">quant_mode</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_int8_kv_cache</span><span class="p">():</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">kv_quantization_scale</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">kv_dequantization_scale</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span>
                                                     <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">register_parameter</span><span class="p">(</span><span class="s1">&#39;kv_quantization_scale&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">register_parameter</span><span class="p">(</span><span class="s1">&#39;kv_dequantization_scale&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">qkv</span> <span class="o">=</span> <span class="n">ColumnLinear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
                                <span class="n">out_features</span><span class="o">=</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="mi">3</span><span class="p">,</span>
                                <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
                                <span class="n">tp_group</span><span class="o">=</span><span class="n">tp_group</span><span class="p">,</span>
                                <span class="n">tp_size</span><span class="o">=</span><span class="n">tp_size</span><span class="p">,</span>
                                <span class="n">gather_output</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
                                <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">dense</span> <span class="o">=</span> <span class="n">RowLinear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
                               <span class="n">out_features</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
                               <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
                               <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
                               <span class="n">tp_group</span><span class="o">=</span><span class="n">tp_group</span><span class="p">,</span>
                               <span class="n">tp_size</span><span class="o">=</span><span class="n">tp_size</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span><!-- NOTE(review): attention_mask is accepted but never read in this method. -->
                <span class="n">hidden_states</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
                <span class="n">attention_mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                <span class="n">use_cache</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
                <span class="n">kv_cache_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                <span class="n">attention_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
                <span class="s1">&#39;GPT-NeoX RoPE is only supported with GPTAttention plugin&#39;</span><span class="p">)</span>
        <span class="n">qkv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
        <span class="k">assert</span> <span class="n">attention_params</span><span class="o">.</span><span class="n">is_valid</span><span class="p">(</span>
            <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span><span class="p">,</span>
            <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">)</span>
        <span class="k">assert</span> <span class="n">kv_cache_params</span><span class="o">.</span><span class="n">is_valid</span><span class="p">(</span>
            <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span><span class="p">)</span>
        <span class="n">context</span><span class="p">,</span> <span class="n">past_key_value</span> <span class="o">=</span> <span class="n">gpt_attention</span><span class="p">(</span>
            <span class="n">tensor</span><span class="o">=</span><span class="n">qkv</span><span class="p">,</span>
            <span class="n">past_key_value</span><span class="o">=</span><span class="n">kv_cache_params</span><span class="o">.</span><span class="n">get_first_past_key_value</span><span class="p">(),</span>
            <span class="n">sequence_length</span><span class="o">=</span><span class="n">attention_params</span><span class="o">.</span><span class="n">sequence_length</span><span class="p">,</span>
            <span class="n">host_past_key_value_lengths</span><span class="o">=</span><span class="n">kv_cache_params</span><span class="o">.</span>
            <span class="n">host_past_key_value_lengths</span><span class="p">,</span>
            <span class="n">context_lengths</span><span class="o">=</span><span class="n">attention_params</span><span class="o">.</span><span class="n">context_lengths</span><span class="p">,</span>
            <span class="n">cache_indirection</span><span class="o">=</span><span class="n">kv_cache_params</span><span class="o">.</span><span class="n">cache_indirection</span><span class="p">,</span>
            <span class="n">host_request_types</span><span class="o">=</span><span class="n">attention_params</span><span class="o">.</span><span class="n">host_request_types</span><span class="p">,</span>
            <span class="n">num_heads</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_attention_heads</span><span class="p">,</span>
            <span class="n">num_kv_heads</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_attention_heads</span><span class="p">,</span>
            <span class="n">hidden_size_per_head</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">attention_head_size</span><span class="p">,</span>
            <span class="n">q_scaling</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span>
            <span class="n">rotary_embedding_dim</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">rotary_dim</span><span class="p">,</span>
            <span class="n">position_embedding_type</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">position_embedding_type</span><span class="p">,</span>
            <span class="n">multi_block_mode</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">multi_block_mode</span><span class="p">,</span>
            <span class="n">kv_orig_quant_scale</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">kv_quantization_scale</span><span class="p">,</span>
            <span class="n">kv_quant_orig_scale</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">kv_dequantization_scale</span><span class="p">,</span>
            <span class="n">kv_cache_quant_mode</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">quant_mode</span><span class="p">,</span>
            <span class="n">max_context_length</span><span class="o">=</span><span class="n">attention_params</span><span class="o">.</span><span class="n">max_context_length</span><span class="p">,</span>
            <span class="n">host_context_lengths</span><span class="o">=</span><span class="n">attention_params</span><span class="o">.</span><span class="n">host_context_lengths</span><span class="p">)</span>
        <span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="n">context</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
            <span class="k">return</span> <span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">past_key_value</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">context</span>
<span class="k">class</span> <span class="nc">GPTNeoXDecoderLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span><!-- Listing of GPTNeoXDecoderLayer. Leading whitespace inside the pre element is rendered verbatim, so each line is indented to mirror the Python block structure. -->
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                 <span class="n">hidden_size</span><span class="p">,</span>
                 <span class="n">num_attention_heads</span><span class="p">,</span>
                 <span class="n">max_position_embeddings</span><span class="p">,</span>
                 <span class="n">num_layers</span><span class="p">,</span>
                 <span class="n">rotary_dim</span><span class="p">,</span>
                 <span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">apply_query_key_layer_scaling</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
                 <span class="n">attention_mask_type</span><span class="o">=</span><span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">causal</span><span class="p">,</span>
                 <span class="n">hidden_act</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">,</span>
                 <span class="n">position_embedding_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">,</span>
                 <span class="n">tp_group</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">tp_size</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span><!-- NOTE(review): num_layers, apply_query_key_layer_scaling and attention_mask_type are accepted but not used in this constructor. -->
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">input_layernorm</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="p">(</span><span class="n">normalized_shape</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
                                         <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">post_attention_layernorm</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="p">(</span><span class="n">normalized_shape</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
                                                  <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">attention</span> <span class="o">=</span> <span class="n">GPTNeoXAttention</span><span class="p">(</span>
            <span class="n">hidden_size</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
            <span class="n">num_attention_heads</span><span class="o">=</span><span class="n">num_attention_heads</span><span class="p">,</span>
            <span class="n">rotary_dim</span><span class="o">=</span><span class="n">rotary_dim</span><span class="p">,</span>
            <span class="n">max_position_embeddings</span><span class="o">=</span><span class="n">max_position_embeddings</span><span class="p">,</span>
            <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
            <span class="n">position_embedding_type</span><span class="o">=</span><span class="n">position_embedding_type</span><span class="p">,</span>
            <span class="n">tp_group</span><span class="o">=</span><span class="n">tp_group</span><span class="p">,</span>
            <span class="n">tp_size</span><span class="o">=</span><span class="n">tp_size</span><span class="p">)</span><!-- NOTE(review): multi_block_mode and quant_mode are not passed here, so GPTNeoXAttention uses its defaults (False and QuantMode(0)). -->
        <span class="bp">self</span><span class="o">.</span><span class="n">mlp</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
                       <span class="n">ffn_hidden_size</span><span class="o">=</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span>
                       <span class="n">hidden_act</span><span class="o">=</span><span class="n">hidden_act</span><span class="p">,</span>
                       <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
                       <span class="n">tp_group</span><span class="o">=</span><span class="n">tp_group</span><span class="p">,</span>
                       <span class="n">tp_size</span><span class="o">=</span><span class="n">tp_size</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                <span class="n">hidden_states</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
                <span class="n">attention_mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                <span class="n">use_cache</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
                <span class="n">kv_cache_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                <span class="n">attention_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">(</span>
        <span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">layernorm_plugin</span> <span class="ow">and</span> <span class="n">trt</span><span class="o">.</span><span class="n">__version__</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;8.6&#39;</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">AssertionError</span><span class="p">(</span>
                <span class="s2">&quot;You need to enable the LayerNorm plugin for GPT-NeoX with TensorRT 8.6. Please set plugin_config.layernorm_plugin&quot;</span>
            <span class="p">)</span>
        <span class="n">residual</span> <span class="o">=</span> <span class="n">hidden_states</span>
        <span class="n">input_layernorm_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_layernorm</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
        <span class="n">post_attention_layernorm_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">post_attention_layernorm</span><span class="p">(</span>
            <span class="n">hidden_states</span><span class="p">)</span><!-- Parallel residual: both layernorms read the same hidden_states, and the attention and MLP outputs are summed with the residual below. -->
        <span class="n">attention_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention</span><span class="p">(</span><span class="n">input_layernorm_output</span><span class="p">,</span>
                                          <span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span>
                                          <span class="n">use_cache</span><span class="o">=</span><span class="n">use_cache</span><span class="p">,</span>
                                          <span class="n">kv_cache_params</span><span class="o">=</span><span class="n">kv_cache_params</span><span class="p">,</span>
                                          <span class="n">attention_params</span><span class="o">=</span><span class="n">attention_params</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
            <span class="n">attention_output</span><span class="p">,</span> <span class="n">presents</span> <span class="o">=</span> <span class="n">attention_output</span>
        <span class="n">feed_forward_hidden_states</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp</span><span class="p">(</span><span class="n">post_attention_layernorm_output</span><span class="p">)</span>
        <span class="n">hidden_states</span> <span class="o">=</span> <span class="n">attention_output</span> <span class="o">+</span> <span class="n">feed_forward_hidden_states</span> <span class="o">+</span> <span class="n">residual</span>
        <span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
            <span class="k">return</span> <span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">presents</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">hidden_states</span>
<div class="viewcode-block" id="GPTNeoXModel"><a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.GPTNeoXModel">[docs]</a><span class="k">class</span> <span class="nc">GPTNeoXModel</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">num_layers</span><span class="p">,</span>
<span class="n">num_heads</span><span class="p">,</span>
<span class="n">hidden_size</span><span class="p">,</span>
<span class="n">vocab_size</span><span class="p">,</span>
<span class="n">hidden_act</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">,</span>
<span class="n">rotary_dim</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">position_embedding_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">,</span>
<span class="n">mapping</span><span class="o">=</span><span class="n">Mapping</span><span class="p">(),</span>
<span class="n">apply_query_key_layer_scaling</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">use_parallel_embedding</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">embedding_sharding_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">embedding</span> <span class="o">=</span> <span class="n">Embedding</span><span class="p">(</span>
<span class="n">num_embeddings</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span>
<span class="n">embedding_dim</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">tp_size</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="k">if</span> <span class="n">use_parallel_embedding</span> <span class="k">else</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">tp_group</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_group</span> <span class="k">if</span> <span class="n">use_parallel_embedding</span> <span class="k">else</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">sharding_dim</span><span class="o">=</span><span class="n">embedding_sharding_dim</span><span class="p">,</span>
<span class="n">tp_rank</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">ModuleList</span><span class="p">([</span>
<span class="n">GPTNeoXDecoderLayer</span><span class="p">(</span>
<span class="n">hidden_size</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
<span class="n">num_attention_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="o">=</span><span class="n">max_position_embeddings</span><span class="p">,</span>
<span class="n">num_layers</span><span class="o">=</span><span class="n">num_layers</span><span class="p">,</span>
<span class="n">rotary_dim</span><span class="o">=</span><span class="n">rotary_dim</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">apply_query_key_layer_scaling</span><span class="o">=</span><span class="n">apply_query_key_layer_scaling</span><span class="p">,</span>
<span class="n">attention_mask_type</span><span class="o">=</span><span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">causal</span><span class="p">,</span>
<span class="n">hidden_act</span><span class="o">=</span><span class="n">hidden_act</span><span class="p">,</span>
<span class="n">position_embedding_type</span><span class="o">=</span><span class="n">position_embedding_type</span><span class="p">,</span>
<span class="n">tp_group</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_group</span><span class="p">,</span>
<span class="n">tp_size</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_layers</span><span class="p">)</span>
<span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ln_f</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="p">(</span><span class="n">normalized_shape</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<!-- GPTNeoXModel.forward (rendered source): embeds input_ids, runs the hidden
     states through every decoder layer in self.layers, then applies the final
     LayerNorm (self.ln_f).
     Returns hidden_states, or (hidden_states, tuple(presents)) when use_cache is
     true, where presents collects element [1] of each layer's output.
     NOTE(review): position_ids is accepted but not referenced in this body.
     NOTE(review): kv_cache_params defaults to None, yet its past_key_value
     attribute is read unconditionally, so calling with the default would raise
     AttributeError; callers presumably always pass it. TODO confirm. --><div class="viewcode-block" id="GPTNeoXModel.forward"><a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.GPTNeoXModel.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">position_ids</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">use_cache</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">kv_cache_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">attention_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
<span class="k">if</span> <span class="n">kv_cache_params</span><span class="o">.</span><span class="n">past_key_value</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span><!-- if no per-layer cache was supplied, use one None entry per layer so the zip() over self.layers still visits every layer -->
<span class="n">kv_cache_params</span><span class="o">.</span><span class="n">past_key_value</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">))</span>
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
<span class="n">presents</span> <span class="o">=</span> <span class="p">[]</span><!-- presents only exists when use_cache is true -->
<span class="k">for</span> <span class="n">layer</span><span class="p">,</span> <span class="n">past</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">,</span> <span class="n">kv_cache_params</span><span class="o">.</span><span class="n">past_key_value</span><span class="p">):</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span>
<span class="n">hidden_states</span><span class="p">,</span>
<span class="n">use_cache</span><span class="o">=</span><span class="n">use_cache</span><span class="p">,</span>
<span class="n">kv_cache_params</span><span class="o">=</span><span class="n">KeyValueCacheParams</span><span class="p">(</span>
<span class="n">past_key_value</span><span class="o">=</span><span class="p">[</span><span class="n">past</span><span class="p">],</span><!-- each layer is given only its own cache entry, wrapped in a list -->
<span class="n">host_past_key_value_lengths</span><span class="o">=</span><span class="n">kv_cache_params</span><span class="o">.</span>
<span class="n">host_past_key_value_lengths</span><span class="p">,</span>
<span class="n">cache_indirection</span><span class="o">=</span><span class="n">kv_cache_params</span><span class="o">.</span><span class="n">cache_indirection</span><span class="p">),</span>
<span class="n">attention_params</span><span class="o">=</span><span class="n">attention_params</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span><!-- with use_cache the layer output is indexed: [1] is collected as that layer's present, [0] continues as the hidden states -->
<span class="n">presents</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ln_f</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
<span class="k">return</span> <span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">presents</span><span class="p">))</span>
<span class="k">return</span> <span class="n">hidden_states</span></div></div>
<!-- GPTNeoXForCausalLM (rendered source): GPTNeoXModel plus an lm_head
     ColumnLinear that projects hidden states onto a vocabulary padded
     according to mapping.tp_size. forward() produces the 'logits' output;
     prepare_inputs() builds the inputs for forward() (per its docstring). --><div class="viewcode-block" id="GPTNeoXForCausalLM"><a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.GPTNeoXForCausalLM">[docs]</a><span class="k">class</span> <span class="nc">GPTNeoXForCausalLM</span><span class="p">(</span><span class="n">GPTNeoXModel</span><span class="p">,</span> <span class="n">GenerationMixin</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">num_layers</span><span class="p">,</span>
<span class="n">num_heads</span><span class="p">,</span>
<span class="n">hidden_size</span><span class="p">,</span>
<span class="n">vocab_size</span><span class="p">,</span>
<span class="n">hidden_act</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">,</span>
<span class="n">rotary_dim</span><span class="p">,</span>
<span class="n">dtype</span><span class="p">,</span>
<span class="n">position_embedding_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">,</span>
<span class="n">mapping</span><span class="o">=</span><span class="n">Mapping</span><span class="p">(),</span>
<span class="n">apply_query_key_layer_scaling</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">use_parallel_embedding</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">embedding_sharding_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span><!-- dtype may be a string (converted with str_dtype_to_trt) or a trt.DataType; either way it is stored in self._kv_dtype -->
<span class="bp">self</span><span class="o">.</span><span class="n">_kv_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_kv_dtype</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_num_layers</span> <span class="o">=</span> <span class="n">num_layers</span><!-- configuration cached on the instance; _num_layers, _num_heads, _hidden_size and _tp_size are read by prepare_inputs() -->
<span class="bp">self</span><span class="o">.</span><span class="n">_num_heads</span> <span class="o">=</span> <span class="n">num_heads</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_hidden_size</span> <span class="o">=</span> <span class="n">hidden_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_vocab_size</span> <span class="o">=</span> <span class="n">vocab_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_tp_size</span> <span class="o">=</span> <span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_use_parallel_embedding</span> <span class="o">=</span> <span class="n">use_parallel_embedding</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_embedding_sharding_dim</span> <span class="o">=</span> <span class="n">embedding_sharding_dim</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><!-- GPTNeoXModel.__init__ builds the embedding, the decoder layers and the final LayerNorm -->
<span class="n">num_layers</span><span class="o">=</span><span class="n">num_layers</span><span class="p">,</span>
<span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span>
<span class="n">hidden_size</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
<span class="n">vocab_size</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span>
<span class="n">hidden_act</span><span class="o">=</span><span class="n">hidden_act</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="o">=</span><span class="n">max_position_embeddings</span><span class="p">,</span>
<span class="n">rotary_dim</span><span class="o">=</span><span class="n">rotary_dim</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">position_embedding_type</span><span class="o">=</span><span class="n">position_embedding_type</span><span class="p">,</span>
<span class="n">mapping</span><span class="o">=</span><span class="n">mapping</span><span class="p">,</span>
<span class="n">apply_query_key_layer_scaling</span><span class="o">=</span><span class="n">apply_query_key_layer_scaling</span><span class="p">,</span>
<span class="n">use_parallel_embedding</span><span class="o">=</span><span class="n">use_parallel_embedding</span><span class="p">,</span>
<span class="n">embedding_sharding_dim</span><span class="o">=</span><span class="n">embedding_sharding_dim</span><span class="p">)</span>
<span class="n">vocab_size_padded</span> <span class="o">=</span> <span class="n">pad_vocab_size</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">)</span><!-- NOTE(review): vocab_size is padded using mapping.tp_size, presumably to a multiple of it so lm_head splits evenly across ranks; confirm against pad_vocab_size -->
<span class="bp">self</span><span class="o">.</span><span class="n">lm_head</span> <span class="o">=</span> <span class="n">ColumnLinear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span>
<span class="n">vocab_size_padded</span><span class="p">,</span>
<span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">tp_group</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_group</span><span class="p">,</span>
<span class="n">tp_size</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
<span class="n">gather_output</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<!-- GPTNeoXForCausalLM.forward: runs GPTNeoXModel.forward, gathers the hidden
     states at last_token_ids via gather_last_token_logits, projects them through
     lm_head and marks the result as the 'logits' network output with dtype
     self._kv_dtype. When use_cache is true and paged_kv_cache is off, each present
     is marked as present_key_value_{i} and (lm_logits, presents) is returned;
     otherwise only lm_logits is returned. --><div class="viewcode-block" id="GPTNeoXForCausalLM.forward"><a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.GPTNeoXForCausalLM.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">position_ids</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">use_cache</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">kv_cache_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">attention_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">position_ids</span><span class="p">,</span> <span class="n">use_cache</span><span class="p">,</span><!-- positional order matches GPTNeoXModel.forward(input_ids, position_ids, use_cache, kv_cache_params, attention_params) -->
<span class="n">kv_cache_params</span><span class="p">,</span> <span class="n">attention_params</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
<span class="n">hidden_states</span><span class="p">,</span> <span class="n">presents</span> <span class="o">=</span> <span class="n">hidden_states</span><!-- with use_cache the base model returns (hidden_states, presents) -->
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">gather_last_token_logits</span><span class="p">(</span>
<span class="n">hidden_states</span><span class="p">,</span> <span class="n">last_token_ids</span><span class="p">,</span>
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">)</span>
<span class="c1"># [batch_size, hidden_size] -&gt; [batch_size, vocab_size]</span>
<span class="n">lm_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lm_head</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
<span class="n">lm_logits</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="s1">&#39;logits&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_kv_dtype</span><span class="p">)</span><!-- the logits output uses the same dtype as the present_key_value outputs (self._kv_dtype) -->
<span class="k">if</span> <span class="n">use_cache</span> <span class="ow">and</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="o">==</span> <span class="kc">False</span><span class="p">:</span><!-- NOTE(review): when use_cache is true but paged_kv_cache is enabled, presents are neither marked as outputs nor returned -->
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">present</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">presents</span><span class="p">):</span>
<span class="n">present</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_kv_dtype</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">lm_logits</span><span class="p">,</span> <span class="n">presents</span><span class="p">)</span>
<span class="k">return</span> <span class="n">lm_logits</span></div>
<!-- GPTNeoXForCausalLM.prepare_inputs: derives per-rank head count and head
     size, builds the input tensors via prepare_basic_inputs, and returns them as
     a tuple in forward()'s positional order (input_ids, position_ids, use_cache,
     last_token_ids, kv_cache_params, attention_params).
     NOTE(review): the use_cache argument is not read in this body; True is always
     returned in the use_cache slot. --><div class="viewcode-block" id="GPTNeoXForCausalLM.prepare_inputs"><a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.GPTNeoXForCausalLM.prepare_inputs">[docs]</a> <span class="k">def</span> <span class="nf">prepare_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_batch_size</span><span class="p">,</span> <span class="n">max_input_len</span><span class="p">,</span> <span class="n">max_new_tokens</span><span class="p">,</span>
<span class="n">use_cache</span><span class="p">,</span> <span class="n">max_beam_width</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the</span>
<span class="sd"> ranges of the dimensions of when using TRT dynamic shapes.</span>
<span class="sd"> @return: a list contains values which can be fed into the self.forward()</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># Prepare inputs</span>
<span class="n">head_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_hidden_size</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">_num_heads</span>
<span class="n">num_heads</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_num_heads</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tp_size</span><!-- attention heads handled by each tensor-parallel rank -->
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
<span class="n">use_gpt_attention_plugin</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">(</span>
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span>
<span class="n">use_gemm_plugin</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_plugin</span>
<span class="n">model_inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_basic_inputs</span><span class="p">(</span>
<span class="n">max_batch_size</span><span class="p">,</span>
<span class="n">max_beam_width</span><span class="p">,</span>
<span class="n">max_input_len</span><span class="p">,</span>
<span class="n">max_new_tokens</span><span class="p">,</span>
<span class="n">num_heads</span><span class="p">,</span>
<span class="n">head_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_num_layers</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_kv_dtype</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">use_gemm_plugin</span><span class="o">=</span><span class="n">use_gemm_plugin</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">&#39;input_ids&#39;</span><span class="p">],</span> <span class="n">model_inputs</span><span class="p">[</span><span class="s1">&#39;position_ids&#39;</span><span class="p">],</span> <span class="kc">True</span><span class="p">,</span><!-- NOTE(review): the use_cache slot is always True regardless of the use_cache argument -->
<span class="n">model_inputs</span><span class="p">[</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">],</span>
<span class="n">KeyValueCacheParams</span><span class="p">(</span>
<span class="n">past_key_value</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">&#39;past_key_value&#39;</span><span class="p">],</span>
<span class="n">host_past_key_value_lengths</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span>
<span class="s1">&#39;host_past_key_value_lengths&#39;</span><span class="p">],</span>
<span class="n">cache_indirection</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">&#39;cache_indirection&#39;</span><span class="p">],</span>
<span class="p">),</span>
<span class="n">AttentionParams</span><span class="p">(</span>
<span class="n">sequence_length</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">&#39;sequence_length&#39;</span><span class="p">],</span>
<span class="n">context_lengths</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">&#39;context_lengths&#39;</span><span class="p">],</span>
<span class="n">host_context_lengths</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">&#39;host_context_lengths&#39;</span><span class="p">],</span>
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_input_len</span><span class="p">,</span>
<span class="n">host_request_types</span><span class="o">=</span><span class="n">model_inputs</span><span class="p">[</span><span class="s1">&#39;host_request_types&#39;</span><span class="p">]))</span></div></div>
</pre></div>
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>&#169; Copyright 2023, NVIDIA.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>