<!DOCTYPE html>
<html class="writer-html5" lang="en" data-content_root="../../../">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>tensorrt_llm.runtime.generation &mdash; tensorrt_llm documentation</title>
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=80d5e7a1" />
<link rel="stylesheet" type="text/css" href="../../../_static/css/theme.css?v=19f00094" />
<!--[if lt IE 9]>
<script src="../../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script src="../../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../../_static/doctools.js?v=888ff710"></script>
<script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../../index.html" class="icon icon-home">
tensorrt_llm
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Contents:</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture.html">TensorRT-LLM Architecture</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../gpt_runtime.html">C++ GPT Runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../batch_manager.html">The Batch Manager in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../gpt_attention.html">Multi-head, Multi-query and Group-query Attention</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../precision.html">Numerical Precision</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../installation.html">Build TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../performance.html">Performance of TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../2023-05-19-how-to-debug.html">How to debug</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../2023-05-17-how-to-add-a-new-model.html">How to add a new model</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../graph-rewriting.html">Graph Rewriting Module</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../memory.html">Memory Usage of TensorRT-LLM</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.layers.html">Layers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.functional.html">Functionals</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.models.html">Models</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../_cpp_gen/runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Blogs</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../../index.html">tensorrt_llm</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../../index.html" class="icon icon-home" aria-label="Home"></a></li>
<li class="breadcrumb-item"><a href="../../index.html">Module code</a></li>
<li class="breadcrumb-item active">tensorrt_llm.runtime.generation</li>
<li class="wy-breadcrumbs-aside">
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<h1>Source code for tensorrt_llm.runtime.generation</h1><div class="highlight"><pre>
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="kn">import</span> <span class="nn">csv</span>
<span class="kn">import</span> <span class="nn">math</span>
<span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span><span class="p">,</span> <span class="n">field</span>
<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">wraps</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">,</span> <span class="n">Union</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="c1"># isort: off</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">tensorrt</span> <span class="k">as</span> <span class="nn">trt</span>
<span class="c1"># isort: on</span>
<span class="kn">from</span> <span class="nn">cuda</span> <span class="kn">import</span> <span class="n">cudart</span>
<span class="kn">from</span> <span class="nn">.._ipc_utils</span> <span class="kn">import</span> <span class="n">IpcMemory</span><span class="p">,</span> <span class="n">set_peer_access</span>
<span class="kn">from</span> <span class="nn">.._utils</span> <span class="kn">import</span> <span class="n">pad_vocab_size</span><span class="p">,</span> <span class="n">str_dtype_to_torch</span><span class="p">,</span> <span class="n">trt_dtype_to_torch</span>
<span class="kn">from</span> <span class="nn">..logger</span> <span class="kn">import</span> <span class="n">logger</span>
<span class="kn">from</span> <span class="nn">..mapping</span> <span class="kn">import</span> <span class="n">Mapping</span>
<span class="kn">from</span> <span class="nn">..quantization</span> <span class="kn">import</span> <span class="n">QuantMode</span>
<span class="kn">from</span> <span class="nn">.kv_cache_manager</span> <span class="kn">import</span> <span class="n">GenerationSequence</span><span class="p">,</span> <span class="n">KVCacheManager</span>
<span class="kn">from</span> <span class="nn">.lora_manager</span> <span class="kn">import</span> <span class="n">LoraManager</span>
<span class="kn">from</span> <span class="nn">.session</span> <span class="kn">import</span> <span class="n">_scoped_stream</span>
<div class="viewcode-block" id="to_word_list_format">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.to_word_list_format">[docs]</a>
<span class="k">def</span> <span class="nf">to_word_list_format</span><span class="p">(</span><span class="n">word_dict</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span>
<span class="n">tokenizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">add_special_tokens</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Format of word_dict:</span>
<span class="sd"> len(word_dict) must equal batch_size.</span>
<span class="sd"> word_dict[i] holds the words for batch item i.</span>
<span class="sd"> len(word_dict[i]) must be 1, i.e. it contains exactly one string.</span>
<span class="sd"> That string may contain several phrases separated by &quot;,&quot;.</span>
<span class="sd"> For example, if word_dict[2] = &quot; I am happy, I am sad&quot;, this function returns</span>
<span class="sd"> the ids of the two phrases &quot; I am happy&quot; and &quot; I am sad&quot;.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">tokenizer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;need to set tokenizer&quot;</span>
<span class="n">flat_ids</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">offsets</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">word_dict_item</span> <span class="ow">in</span> <span class="n">word_dict</span><span class="p">:</span>
<span class="n">item_flat_ids</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">item_offsets</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">word_dict_item</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="nb">bytes</span><span class="p">):</span>
<span class="n">word_dict_item</span> <span class="o">=</span> <span class="p">[</span><span class="n">word_dict_item</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">decode</span><span class="p">()]</span>
<span class="n">words</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">csv</span><span class="o">.</span><span class="n">reader</span><span class="p">(</span><span class="n">word_dict_item</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">for</span> <span class="n">word</span> <span class="ow">in</span> <span class="n">words</span><span class="p">:</span>
<span class="n">ids</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">word</span><span class="p">,</span> <span class="n">add_special_tokens</span><span class="o">=</span><span class="n">add_special_tokens</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">item_flat_ids</span> <span class="o">+=</span> <span class="n">ids</span>
<span class="n">item_offsets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">))</span>
<span class="n">flat_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">item_flat_ids</span><span class="p">))</span>
<span class="n">offsets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">item_offsets</span><span class="p">)))</span>
<span class="n">pad_to</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">max</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">)</span> <span class="k">for</span> <span class="n">ids</span> <span class="ow">in</span> <span class="n">flat_ids</span><span class="p">))</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">offs</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">flat_ids</span><span class="p">,</span> <span class="n">offsets</span><span class="p">)):</span>
<span class="n">flat_ids</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">pad_to</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">)),</span> <span class="n">constant_values</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">offsets</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="n">offs</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">pad_to</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">offs</span><span class="p">)),</span> <span class="n">constant_values</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">flat_ids</span><span class="p">,</span> <span class="n">offsets</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;int32&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span></div>
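A minimal, self-contained usage sketch of to_word_list_format. ToyTokenizer below is an assumed stand-in for a real tokenizer exposing encode(text, add_special_tokens=...); the function body is repeated so the snippet runs on its own:

```python
import csv
import numpy as np

# Reproduced from the listing above so this sketch is self-contained.
def to_word_list_format(word_dict, tokenizer=None, add_special_tokens=False):
    assert tokenizer is not None, "need to set tokenizer"
    flat_ids = []
    offsets = []
    for word_dict_item in word_dict:
        item_flat_ids = []
        item_offsets = []
        if isinstance(word_dict_item[0], bytes):
            word_dict_item = [word_dict_item[0].decode()]
        # csv.reader splits the single string on "," into separate phrases
        words = list(csv.reader(word_dict_item))[0]
        for word in words:
            ids = tokenizer.encode(word, add_special_tokens=add_special_tokens)
            if len(ids) == 0:
                continue
            item_flat_ids += ids
            item_offsets.append(len(ids))
        flat_ids.append(np.array(item_flat_ids))
        offsets.append(np.cumsum(np.array(item_offsets)))
    pad_to = max(1, max(len(ids) for ids in flat_ids))
    for i, (ids, offs) in enumerate(zip(flat_ids, offsets)):
        flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0)
        offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1)
    return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2))

class ToyTokenizer:
    """Hypothetical tokenizer: maps each whitespace token to a fixed id."""
    def __init__(self):
        self.vocab = {}
    def encode(self, text, add_special_tokens=False):
        return [self.vocab.setdefault(t, len(self.vocab) + 1)
                for t in text.split()]

# One batch entry, one comma-separated string of two phrases.
word_dict = [["I am happy,I am sad"]]
out = to_word_list_format(word_dict, tokenizer=ToyTokenizer())
# Result shape is (batch_size, 2, pad_to): row 0 holds the flattened token
# ids, row 1 holds cumulative phrase end offsets padded with -1.
print(out.shape)   # (1, 2, 6)
print(out[0, 1])   # offsets row: 3, 6, then -1 padding
```

This (ids, offsets) layout is what the runtime typically consumes for its bad-words / stop-words inputs.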
<span class="k">def</span> <span class="nf">_prepare_input_ids</span><span class="p">(</span><span class="n">tensors</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span>
<span class="n">tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">t</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">]</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span><span class="n">tensors</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">row_lengths</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">]</span>
<span class="n">row_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">row_lengths</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">data</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">row_lengths</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">CUASSERT</span><span class="p">(</span><span class="n">cuda_ret</span><span class="p">):</span>
<span class="n">err</span> <span class="o">=</span> <span class="n">cuda_ret</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="n">err</span> <span class="o">!=</span> <span class="n">cudart</span><span class="o">.</span><span class="n">cudaError_t</span><span class="o">.</span><span class="n">cudaSuccess</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;CUDA ERROR: </span><span class="si">{</span><span class="n">err</span><span class="si">}</span><span class="s2">, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">cuda_ret</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">return</span> <span class="n">cuda_ret</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="k">def</span> <span class="nf">_update_cuda_graph_instance</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">graph</span><span class="p">):</span>
<span class="n">err</span> <span class="o">=</span> <span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphExecUpdate</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">graph</span><span class="p">)</span>
<span class="k">if</span> <span class="n">err</span> <span class="o">!=</span> <span class="n">cudart</span><span class="o">.</span><span class="n">cudaError_t</span><span class="o">.</span><span class="n">cudaSuccess</span><span class="p">:</span>
<span class="c1"># If updating the CUDA graph fails, destroy it and instantiate a new one.</span>
<span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphExecDestroy</span><span class="p">(</span><span class="n">instance</span><span class="p">))</span>
<span class="n">instance</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphInstantiate</span><span class="p">(</span><span class="n">graph</span><span class="p">,</span> <span class="mi">0</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">return</span> <span class="n">instance</span>
<span class="k">def</span> <span class="nf">_prepare_attention_mask</span><span class="p">(</span><span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">pad_id</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="n">is_pad_id_in_inputs</span> <span class="o">=</span> <span class="p">(</span><span class="n">pad_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span><span class="n">pad_id</span> <span class="ow">in</span> <span class="n">input_ids</span><span class="p">)</span>
<span class="k">if</span> <span class="n">input_ids</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">is_pad_id_in_inputs</span><span class="p">:</span>
<span class="k">return</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">ne</span><span class="p">(</span><span class="n">pad_id</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_tile_beam_width</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">new_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">num_beams</span>
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="n">tile_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">tile_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
<span class="k">return</span> <span class="n">new_tensor</span>
<span class="k">class</span> <span class="nc">_Runtime</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="n">runtime_rank</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">runtime</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">Runtime</span>
<span class="n">engine</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ICudaEngine</span>
<span class="n">ctx_context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span>
<span class="n">context_0</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span>
<span class="n">context_1</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span>
<span class="n">cuda_graph_instances</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphExec_t</span><span class="p">]</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">,</span> <span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">__prepare</span><span class="p">(</span><span class="n">mapping</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">__create_and_setup_context</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">address</span><span class="p">,</span> <span class="n">profile_idx</span><span class="p">,</span>
<span class="n">stream</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">:</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">create_execution_context_without_device_memory</span><span class="p">()</span>
<span class="k">assert</span> <span class="n">context</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">context</span><span class="o">.</span><span class="n">device_memory</span> <span class="o">=</span> <span class="n">address</span>
<span class="n">context</span><span class="o">.</span><span class="n">set_optimization_profile_async</span><span class="p">(</span><span class="n">profile_idx</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="k">return</span> <span class="n">context</span>
<span class="k">def</span> <span class="nf">__prepare</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime_rank</span> <span class="o">=</span> <span class="n">mapping</span><span class="o">.</span><span class="n">rank</span>
<span class="n">local_rank</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime_rank</span> <span class="o">%</span> <span class="n">mapping</span><span class="o">.</span><span class="n">gpus_per_node</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="n">local_rank</span><span class="p">)</span>
<span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaSetDevice</span><span class="p">(</span><span class="n">local_rank</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Runtime</span><span class="p">(</span><span class="n">logger</span><span class="o">.</span><span class="n">trt_logger</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">engine</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">deserialize_cuda_engine</span><span class="p">(</span><span class="n">engine_buffer</span><span class="p">)</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="c1"># device_memory_size is the memory required by the largest optimization profile</span>
<span class="n">address</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaMalloc</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">address</span> <span class="o">=</span> <span class="n">address</span>
<span class="c1"># cuda graph ping-pong instances</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_instances</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">)]</span>
<span class="k">with</span> <span class="n">_scoped_stream</span><span class="p">()</span> <span class="k">as</span> <span class="n">stream</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_optimization_profiles</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># At step = 0, context_1 is active</span>
<span class="c1"># At step = 1, context_0 is active</span>
<span class="c1"># At step = 2, context_1 is active</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="n">address</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="n">address</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">context_1</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_optimization_profiles</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
<span class="c1"># At step = 0, ctx_context is active</span>
<span class="c1"># At step = 1, context_0 is active</span>
<span class="c1"># At step = 2, context_1 is active</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="n">address</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="n">address</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="n">address</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">&quot;At most two optimization profiles are supported&quot;</span>
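The ping-pong scheduling spelled out in the comments above (generation step 0 runs on context_1, step 1 on context_0, step 2 on context_1, and so on) reduces to a parity rule. A minimal illustrative sketch, not the session's actual dispatch code (`pick_generation_context` is a hypothetical helper):

```python
def pick_generation_context(step: int) -> str:
    # Even generation steps run on context_1, odd steps on context_0,
    # so the two execution contexts alternate (ping-pong) across steps.
    # Names mirror the attributes created above; sketch only.
    return "context_1" if step % 2 == 0 else "context_0"

# The first four steps alternate between the two contexts.
print([pick_generation_context(s) for s in range(4)])
# ['context_1', 'context_0', 'context_1', 'context_0']
```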
<span class="k">def</span> <span class="nf">_set_shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
<span class="n">shape_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">):</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_mode</span><span class="p">(</span><span class="n">name</span><span class="p">)</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">TensorIOMode</span><span class="o">.</span><span class="n">INPUT</span><span class="p">:</span>
<span class="n">ok</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">set_input_shape</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">shape_dict</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;setting input tensor </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> with shape </span><span class="si">{</span><span class="n">shape_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">ok</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Couldn&#39;t assign </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> with shape </span><span class="si">{</span><span class="n">shape_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;engine supports [min, opt, max] = </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_profile_shape</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">active_optimization_profile</span><span class="p">,</span><span class="w"> </span><span class="n">name</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">_set_buffer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
<span class="n">buffer_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">):</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="k">if</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">buffer_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">buffer_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">shape</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">buffer_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">(</span>
<span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> is not contiguous&quot;</span>
<span class="n">context</span><span class="o">.</span><span class="n">set_tensor_address</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">buffer_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">())</span>
<span class="k">def</span> <span class="nf">_run</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
<span class="n">stream</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">if</span> <span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">cuda_stream</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">stream</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">):</span>
<span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span><span class="o">.</span><span class="n">cuda_stream</span>
<span class="n">ok</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">execute_async_v3</span><span class="p">(</span><span class="n">stream</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ok</span>
<span class="k">def</span> <span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaFree</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">)</span>
<div class="viewcode-block" id="ModelConfig">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.ModelConfig">[docs]</a>
<span class="nd">@dataclass</span>
<span class="k">class</span> <span class="nc">ModelConfig</span><span class="p">:</span>
<span class="n">vocab_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">num_kv_heads</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">gpt_attention_plugin</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">remove_input_padding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">model_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
<span class="n">paged_kv_cache</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">cross_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">head_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">has_position_embedding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">has_token_type_embedding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">tokens_per_block</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span>
<span class="n">max_prompt_embedding_table_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">quant_mode</span><span class="p">:</span> <span class="n">QuantMode</span> <span class="o">=</span> <span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">gather_all_token_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
<span class="n">use_custom_all_reduce</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">lora_plugin</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></div>
<span class="nd">@dataclass</span>
<span class="k">class</span> <span class="nc">SamplingConfig</span><span class="p">:</span>
<span class="n">end_id</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">pad_id</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span>
<span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">max_kv_cache_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">temperature</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="n">top_k</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">top_p</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="n">length_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="n">repetition_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="n">min_length</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">presence_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="n">use_beam_hyps</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1">## None here means the user didn&#39;t set it; dynamicDecodeOp.cpp takes an optional value</span>
<span class="c1">## The real default value is set in dynamicDecodeOp.cpp when it&#39;s None</span>
<span class="n">beam_search_diversity_rate</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="n">random_seed</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">output_cum_log_probs</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">output_log_probs</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">unused_kwargs</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">):</span>
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">unused_kwargs</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
<span class="k">return</span> <span class="n">unused_kwargs</span>
<div class="viewcode-block" id="GenerationSession">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession">[docs]</a>
<span class="k">class</span> <span class="nc">GenerationSession</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="n">_model_config</span><span class="p">:</span> <span class="n">ModelConfig</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span>
<span class="n">runtime</span><span class="p">:</span> <span class="n">_Runtime</span>
<span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">buffer_allocated</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">debug_mode</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">quant_mode</span><span class="p">:</span> <span class="n">QuantMode</span>
<span class="n">cuda_graph_mode</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">dtype</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span>
<span class="n">debug_tensors_to_save</span><span class="p">:</span> <span class="kc">None</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span> <span class="n">ModelConfig</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span> <span class="o">=</span> <span class="n">model_config</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span> <span class="o">=</span> <span class="n">mapping</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span> <span class="o">=</span> <span class="n">_Runtime</span><span class="p">(</span><span class="n">engine_buffer</span><span class="p">,</span> <span class="n">mapping</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;cuda:</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">runtime_rank</span><span class="w"> </span><span class="o">%</span><span class="w"> </span><span class="n">mapping</span><span class="o">.</span><span class="n">gpus_per_node</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># dynamic_decoder currently uses torch&#39;s current stream, so TRT enqueue must use the same stream here</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="o">=</span> <span class="n">debug_mode</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="o">=</span> <span class="n">debug_tensors_to_save</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span> <span class="o">=</span> <span class="n">cuda_graph_mode</span>
<span class="c1"># Optional inputs for dynamic decoder</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_decay</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_min</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_reset_ids</span> <span class="o">=</span> <span class="kc">None</span>
<span class="c1"># TODO: in tensorrt_llm/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp this is templated on T; can it be float or half?</span>
<span class="bp">self</span><span class="o">.</span><span class="n">embedding_bias_opt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span> <span class="o">=</span> <span class="n">pad_vocab_size</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;The paged KV cache in the Python runtime is experimental. For performance and correctness, please use the C++ runtime.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">classes</span><span class="o">.</span><span class="n">FasterTransformer</span><span class="o">.</span><span class="n">NcclCommunicatorOp</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;logits&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">]:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Logits dtype not supported by decoder. Falling back to float32. You may want to change the logits dtype to float16 in your model definition.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">classes</span><span class="o">.</span><span class="n">FasterTransformer</span><span class="o">.</span><span class="n">DynamicDecodeOp</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gather_tree</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">gather_tree</span>
<span class="n">expected_tensor_names</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;input_ids&#39;</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;hidden_states_input&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">model_config</span><span class="o">.</span><span class="n">gather_all_token_logits</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;hidden_states_output&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">has_position_embedding</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;position_ids&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">has_token_type_embedding</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;token_type_ids&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;cache_indirection&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;kv_cache_block_pointers_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;sequence_length&#39;</span><span class="p">,</span> <span class="s1">&#39;context_lengths&#39;</span><span class="p">,</span> <span class="s1">&#39;host_request_types&#39;</span><span class="p">,</span>
<span class="s1">&#39;host_past_key_value_lengths&#39;</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;host_max_kv_cache_length_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">expected_tensor_names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s1">&#39;host_context_lengths&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;attention_mask&#39;</span><span class="p">,</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;prompt_embedding_table&#39;</span><span class="p">,</span> <span class="s1">&#39;tasks&#39;</span><span class="p">,</span> <span class="s1">&#39;prompt_vocab_size&#39;</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;encoder_output&#39;</span><span class="p">,</span> <span class="s1">&#39;encoder_input_lengths&#39;</span><span class="p">,</span>
<span class="s1">&#39;encoder_max_input_length&#39;</span>
<span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">model_config</span><span class="o">.</span><span class="n">use_custom_all_reduce</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;all_reduce_workspace&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">lora_plugin</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;lora_ranks&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;lora_weights_pointers_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">found_tensor_names</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="ow">and</span> <span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">set</span><span class="p">(</span>
<span class="n">found_tensor_names</span><span class="p">):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;The following expected tensors are not found: </span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">)</span><span class="o">.</span><span class="n">difference</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">))</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Those tensors in engine are not expected: </span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">)</span><span class="o">.</span><span class="n">difference</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">))</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Expected tensor names: </span><span class="si">{</span><span class="n">expected_tensor_names</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Found tensor names: </span><span class="si">{</span><span class="n">found_tensor_names</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;Tensor names in engine are not the same as expected, to use this GenerationSession, &quot;</span> \
<span class="s2">&quot;you need to use GPTLMHeadModel.prepare_inputs to create TRT Network inputs.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
<span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">)</span> <span class="o">-</span> <span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">))</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">vocab_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">vocab_size</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">num_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> \
<span class="sa">f</span><span class="s2">&quot;num_layers </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_layers</span><span class="si">}</span><span class="s2"> must be a multiple of pipeline parallelism size </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">first_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_rank</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">last_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">num_heads</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_heads</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">hidden_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">hidden_size</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">use_gpt_attention_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">paged_kv_cache</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">paged_kv_cache</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">tokens_per_block</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">tokens_per_block</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">remove_input_padding</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">num_heads_kv</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">head_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">head_size</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">head_size</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">quant_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">quant_mode</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">gather_all_token_logits</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gather_all_token_logits</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="n">str_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">use_custom_all_reduce</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">use_custom_all_reduce</span>
<div class="viewcode-block" id="GenerationSession.cuda_stream_guard">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.cuda_stream_guard">[docs]</a>
<span class="k">def</span> <span class="nf">cuda_stream_guard</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Sync external stream and set current stream to the one bound to the session. Reset on exit.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@wraps</span><span class="p">(</span><span class="n">func</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">wrapper</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">external_stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span>
<span class="k">if</span> <span class="n">external_stream</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">:</span>
<span class="n">external_stream</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">)</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">func</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">if</span> <span class="n">external_stream</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="n">external_stream</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ret</span>
<span class="k">return</span> <span class="n">wrapper</span></div>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">cross_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">cross_attention</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">has_position_embedding</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">has_position_embedding</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">has_token_type_embedding</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">has_token_type_embedding</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">use_lora_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">lora_plugin</span>
<span class="k">def</span> <span class="nf">__setup_decoder</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;Allocate buffers and setup the post-processing decoder kernel</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">scfg</span> <span class="o">=</span> <span class="n">sampling_config</span> <span class="c1"># just to make a shorter name, no other meaning</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.top_k.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.top_k.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.top_p.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.top_p.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.temperature.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.temperature.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">temperature</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">temperature</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.repetition_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.repetition_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">==</span> <span class="mf">1.0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">length_penalty</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.presence_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.presence_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">==</span> <span class="mf">0.0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">assert</span> <span class="p">(</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">==</span> <span class="mf">0.0</span> <span class="ow">or</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">==</span> <span class="mf">1.0</span>
<span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;presence_penalty(</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="si">}</span><span class="s2">) and repetition_penalty(</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="si">}</span><span class="s2">) cannot be non-default values at the same time.&quot;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.min_length.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.min_length.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min_length</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.beam_search_diversity_rate.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.beam_search_diversity_rate.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.random_seed.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int64&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.random_seed.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span><span class="o">.</span><span class="n">setup</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_p_decay</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_min</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_p_reset_ids</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;end_id cannot be None&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;pad_id cannot be None&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">max</span><span class="p">()</span>
<span class="c1"># setup output ids buffer</span>
<span class="k">if</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span>
<span class="c1"># dim 0 of input_ids is not batch size, which means remove_padding is enabled</span>
<span class="n">split_ids_list</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
<span class="n">padded_input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">to_padded_tensor</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">nested_tensor</span><span class="p">(</span><span class="n">split_ids_list</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">padded_input_ids</span> <span class="o">=</span> <span class="n">input_ids</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">tiled_input_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">padded_input_ids</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">)</span>
<span class="n">tiled_input_ids</span> <span class="o">=</span> <span class="n">tiled_input_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">)</span>
<span class="n">tiled_input_ids</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># TODO: delete?</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">(</span><span class="n">tiled_input_ids</span><span class="p">,</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">-</span> <span class="n">max_context_length</span><span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)),</span>
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">(</span><span class="n">padded_input_ids</span><span class="p">,</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">-</span> <span class="n">max_context_length</span><span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)),</span>
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># Note: we still allocate max_seq_length size of parent ids (not max_kv_cache_length).</span>
<span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_cum_log_probs</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
<span class="o">-</span><span class="mf">1e20</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_log_probs</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">finished</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_tgt</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
<span class="n">size</span><span class="o">=</span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">],</span>
<span class="n">fill_value</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_sequence_lengths_tgt</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_tgt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_sequence_lengths_tgt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">def</span> <span class="nf">_tensor_dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="c1"># Return the torch dtype of the named engine tensor, for convenience.</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
<span class="k">return</span> <span class="n">dtype</span>
<div class="viewcode-block" id="GenerationSession.setup">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.setup">[docs]</a>
<span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">max_kv_cache_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_max_input_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_manager</span><span class="p">:</span> <span class="n">LoraManager</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_uids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="c1"># Store these params related to buffer size to check against</span>
<span class="c1"># the input shape with the params given in decode()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_context_length</span> <span class="o">=</span> <span class="n">max_context_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">=</span> <span class="n">max_new_tokens</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">=</span> <span class="n">max_context_length</span> <span class="o">+</span> <span class="n">max_new_tokens</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_width</span> <span class="o">=</span> <span class="n">beam_width</span>
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span> <span class="o">=</span> <span class="n">encoder_max_input_length</span>
<span class="k">if</span> <span class="n">max_kv_cache_length</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="s2">&quot;The max_kv_cache_length is not set, we will use max_seq_length by default.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_kv_cache_lengths</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_kv_cache_length</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="k">if</span> <span class="n">max_kv_cache_length</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;The value of max_kv_cache_length should ideally not exceed max_seq_length. &quot;</span>
<span class="s2">&quot;Therefore, it has been adjusted to match the value of max_seq_length.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">max_kv_cache_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_kv_cache_lengths</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_kv_cache_length</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">max_kv_cache_length</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;The value of max_kv_cache_length should ideally not exceed max_seq_length. &quot;</span>
<span class="s2">&quot;Therefore, it has been adjusted to match the value of max_seq_length.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
<span class="k">if</span> <span class="n">max_kv_cache_length</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="s2">&quot;max_kv_cache_length tensor&#39;s size is not equal to num_layers! &quot;</span>
<span class="s2">&quot;Note that num_layers = num_total_layers // pipeline_parallelism_size.&quot;</span>
<span class="p">)</span>
<span class="k">assert</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_kv_cache_lengths</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span>
<span class="n">max_kv_cache_length</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)[</span><span class="n">i</span><span class="p">],</span>
<span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">]))</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">&quot;invalid max_kv_cache_length!&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span> <span class="o">=</span> <span class="n">lora_manager</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_all_token_logits</span> <span class="k">else</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;logits&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="c1"># Use the tensor shape to pass the max encoder input length in remove-padding mode.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">(</span><span class="n">encoder_max_input_length</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">blocks</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">)</span>
<span class="n">cache_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">blocks</span><span class="p">,</span>
<span class="mi">2</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">cache_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="mi">2</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_cache_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="mi">2</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_kv_cache_quant</span><span class="p">():</span>
<span class="c1"># torch does not support fp8 yet, so int8 is used here.</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">int8</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">cache_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">cross_cache_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># Without the plugin, we need two sets of KV cache buffers:</span>
<span class="c1"># one for inputs and the other for outputs.</span>
<span class="c1"># They take turns acting as the input and output buffers.</span>
<span class="c1"># Not applicable to cross-attention KV buffers, as those are constant.</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;1_present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">cache_shape</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_custom_all_reduce</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">set_peer_access</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="p">)</span>
<span class="n">float_element_size</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span>
<span class="n">buffer_size</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span> <span class="o">*</span> <span class="n">max_context_length</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">*</span> <span class="n">float_element_size</span>
<span class="n">barrier_size</span> <span class="o">=</span> <span class="n">IpcMemory</span><span class="o">.</span><span class="n">IPC_BARRIERS_SIZE_PER_GPU</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ipc_buffers</span> <span class="o">=</span> <span class="n">IpcMemory</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="p">,</span> <span class="n">buffer_size</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ipc_barriers_in</span> <span class="o">=</span> <span class="n">IpcMemory</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="p">,</span> <span class="n">barrier_size</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ipc_barriers_out</span> <span class="o">=</span> <span class="n">IpcMemory</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="p">,</span> <span class="n">barrier_size</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ipc_buffers</span><span class="o">.</span><span class="n">serialize</span><span class="p">()</span> <span class="o">+</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ipc_barriers_in</span><span class="o">.</span><span class="n">serialize</span><span class="p">()</span> <span class="o">+</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ipc_barriers_out</span><span class="o">.</span><span class="n">serialize</span><span class="p">(),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">lora_uids</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">lora_weights_pointers_list</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">)</span>
<span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="s1">&#39;lora_ranks&#39;</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="p">})</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="p">})</span>
<span class="k">for</span> <span class="n">batch_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">lora_uid</span> <span class="o">=</span> <span class="n">lora_uids</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">]</span>
<span class="k">if</span> <span class="n">lora_uid</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;lora_ranks&#39;</span><span class="p">][</span>
<span class="n">batch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span><span class="o">.</span><span class="n">uid_to_low_ranks</span><span class="p">(</span>
<span class="n">lora_uid</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">][</span>
<span class="n">batch_idx</span><span class="p">][</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span><span class="o">.</span><span class="n">lora_weights_pointers_list</span><span class="p">[</span>
<span class="n">layer_idx</span><span class="p">][</span><span class="n">lora_uid</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">][</span>
<span class="n">batch_idx</span><span class="p">][</span>
<span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span><span class="o">.</span><span class="n">lora_weights_pointers_list</span><span class="p">[</span>
<span class="n">layer_idx</span><span class="p">][</span><span class="n">lora_uid</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;lora_ranks&#39;</span><span class="p">][</span><span class="n">batch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span> <span class="o">=</span> <span class="kc">True</span></div>
<span class="k">def</span> <span class="nf">_get_context_shape_buffer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">position_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">hidden_states_input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="n">ctx_shape</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;context_lengths&#39;</span><span class="p">:</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="s1">&#39;cache_indirection&#39;</span><span class="p">:</span> <span class="n">cache_indirection</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="p">}</span>
<span class="n">ctx_buffer</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;context_lengths&#39;</span><span class="p">:</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">contiguous</span><span class="p">(),</span>
<span class="s1">&#39;cache_indirection&#39;</span><span class="p">:</span> <span class="n">cache_indirection</span><span class="o">.</span><span class="n">contiguous</span><span class="p">(),</span>
<span class="p">}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="s1">&#39;position_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ctx_buffer</span><span class="p">[</span><span class="s1">&#39;position_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="s1">&#39;encoder_output&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="s1">&#39;encoder_input_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">encoder_input_lengths</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ctx_buffer</span><span class="p">[</span><span class="s1">&#39;encoder_output&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">ctx_buffer</span><span class="p">[</span>
<span class="s1">&#39;encoder_input_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">encoder_input_lengths</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">ctx_buffer</span><span class="p">[</span><span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
<span class="n">hidden_states_input</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span>
<span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">hidden_size</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">ctx_buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_all_token_logits</span><span class="p">:</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ctx_buffer</span><span class="p">[</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="s1">&#39;hidden_states_output&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ctx_buffer</span><span class="p">[</span><span class="s1">&#39;hidden_states_output&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">contiguous</span><span class="p">(</span>
<span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="s1">&#39;input_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ctx_buffer</span><span class="p">[</span><span class="s1">&#39;input_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="s1">&#39;hidden_states_input&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ctx_buffer</span><span class="p">[</span><span class="s1">&#39;hidden_states_input&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="k">if</span> <span class="n">prompt_embedding_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">ctx_buffer</span><span class="p">[</span>
<span class="s1">&#39;prompt_embedding_table&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">prompt_embedding_table</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="s1">&#39;prompt_embedding_table&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">prompt_embedding_table</span><span class="o">.</span><span class="n">shape</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">tasks_generation</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">([</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">context_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">()],</span>
<span class="n">tasks</span><span class="p">[</span><span class="n">b</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">tasks_generation</span> <span class="o">=</span> <span class="n">tasks</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">ctx_buffer</span><span class="p">[</span><span class="s1">&#39;tasks&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">tasks_generation</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="s1">&#39;tasks&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">tasks_generation</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ctx_buffer</span><span class="p">[</span><span class="s1">&#39;prompt_vocab_size&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">prompt_vocab_size</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="s1">&#39;prompt_vocab_size&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">prompt_vocab_size</span><span class="o">.</span><span class="n">shape</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
<span class="n">ctx_buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;kv_cache_block_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">kv_cache_block_pointers</span><span class="p">[</span>
<span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">kv_cache_block_pointers</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
<span class="n">shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="o">*</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:]]</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;kv_cache_block_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">shape</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">kv_cache_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
<span class="c1"># for empty tensor, TRT does not really use the tensor data, so any dtype is fine</span>
<span class="n">kv_cache_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">ctx_shape</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span> <span class="n">kv_cache_shape</span><span class="p">,</span>
<span class="p">})</span>
<span class="n">ctx_buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">kv_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="p">})</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_kv_cache_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
<span class="c1"># for empty tensor, TRT does not really use the tensor data, so any dtype is fine</span>
<span class="n">cross_kv_cache_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">ctx_shape</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">cross_kv_cache_shape</span><span class="p">,</span>
<span class="p">})</span>
<span class="n">ctx_buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">cross_kv_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="p">})</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">key_value_cache</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="n">cache_shape</span> <span class="o">=</span> <span class="n">key_value_cache</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ctx_shape</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span> <span class="n">cache_shape</span><span class="p">,</span>
<span class="p">})</span>
<span class="n">ctx_buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">key_value_cache</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">key_value_cache</span><span class="p">,</span>
<span class="p">})</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_cache_shape</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
<span class="n">cross_cache_buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="n">ctx_shape</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">cross_cache_shape</span><span class="p">,</span>
<span class="p">})</span>
<span class="n">ctx_buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">cross_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">cross_cache_buffer</span>
<span class="p">})</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="c1"># context request</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">ctx_shape</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="s1">&#39;sequence_length&#39;</span><span class="p">:</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
<span class="s1">&#39;host_past_key_value_lengths&#39;</span><span class="p">:</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
<span class="s1">&#39;host_request_types&#39;</span><span class="p">:</span> <span class="n">host_request_types</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="p">})</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="n">ctx_shape</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;host_max_kv_cache_length_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span>
<span class="p">})</span>
<span class="n">ctx_buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;host_max_kv_cache_length_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_kv_cache_lengths</span><span class="p">[</span><span class="n">idx</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">],</span>
<span class="p">})</span>
<span class="n">ctx_buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="s1">&#39;sequence_length&#39;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span>
<span class="s1">&#39;host_past_key_value_lengths&#39;</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
<span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span>
<span class="p">),</span> <span class="c1"># field 0: past_key_value_length, field 1: is_context (deprecated). changed to [0], otherwise affects batch padded input mode</span>
<span class="s1">&#39;host_request_types&#39;</span><span class="p">:</span>
<span class="n">host_request_types</span><span class="o">.</span><span class="n">contiguous</span><span class="p">(),</span>
<span class="p">})</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">ctx_buffer</span><span class="p">[</span>
<span class="s1">&#39;host_context_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="s1">&#39;host_context_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">shape</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">ctx_shape</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">:</span> <span class="n">attention_mask</span><span class="o">.</span><span class="n">shape</span><span class="p">})</span>
<span class="n">ctx_buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">:</span> <span class="n">attention_mask</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()})</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_custom_all_reduce</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="s1">&#39;all_reduce_workspace&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ctx_buffer</span><span class="p">[</span><span class="s1">&#39;all_reduce_workspace&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span><span class="p">:</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="s1">&#39;lora_ranks&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;lora_ranks&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ctx_buffer</span><span class="p">[</span><span class="s1">&#39;lora_ranks&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;lora_ranks&#39;</span><span class="p">]</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
<span class="n">ctx_shape</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ctx_buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="k">return</span> <span class="n">ctx_shape</span><span class="p">,</span> <span class="n">ctx_buffer</span>
<span class="k">def</span> <span class="nf">_get_next_step_shape_buffer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">position_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">hidden_states_input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="n">next_step_shape</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;context_lengths&#39;</span><span class="p">:</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="s1">&#39;cache_indirection&#39;</span><span class="p">:</span> <span class="n">cache_indirection</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="p">}</span>
<span class="n">next_step_buffer</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;context_lengths&#39;</span><span class="p">:</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">contiguous</span><span class="p">(),</span>
<span class="s1">&#39;cache_indirection&#39;</span><span class="p">:</span> <span class="n">cache_indirection</span><span class="o">.</span><span class="n">contiguous</span><span class="p">(),</span>
<span class="p">}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
<span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">hidden_size</span><span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="k">else</span> <span class="p">(</span>
<span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
<span class="n">hidden_states_input</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span><span class="o">*</span><span class="n">shape</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">next_step_buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_all_token_logits</span><span class="p">:</span>
<span class="n">next_step_shape</span><span class="p">[</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">shape</span>
<span class="n">next_step_buffer</span><span class="p">[</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">next_step_shape</span><span class="p">[</span><span class="s1">&#39;hidden_states_output&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">shape</span>
<span class="n">next_step_buffer</span><span class="p">[</span>
<span class="s1">&#39;hidden_states_output&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="n">next_step_shape</span><span class="p">[</span><span class="s1">&#39;input_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span>
<span class="mi">1</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">*</span>
<span class="n">beam_width</span><span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="k">else</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">next_step_buffer</span><span class="p">[</span><span class="s1">&#39;input_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">next_step_shape</span><span class="p">[</span><span class="s1">&#39;hidden_states_input&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">shape</span>
<span class="n">next_step_buffer</span><span class="p">[</span>
<span class="s1">&#39;hidden_states_input&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">next_step_shape</span><span class="p">[</span><span class="s1">&#39;host_context_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">shape</span>
<span class="n">next_step_buffer</span><span class="p">[</span>
<span class="s1">&#39;host_context_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
<span class="n">next_step_shape</span><span class="p">[</span><span class="s1">&#39;position_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">shape</span>
<span class="n">next_step_buffer</span><span class="p">[</span><span class="s1">&#39;position_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="c1"># hack: disable (or minimize) cross qkv computation at generation phase</span>
<span class="c1"># TODO: enable [0,0,.] true zero tensor input; or use IfConditionalLayer</span>
<span class="n">next_step_shape</span><span class="p">[</span><span class="s1">&#39;encoder_output&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span>
<span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="p">]</span> <span class="c1"># encoder_output.shape</span>
<span class="n">next_step_shape</span><span class="p">[</span>
<span class="s1">&#39;encoder_input_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">encoder_input_lengths</span><span class="o">.</span><span class="n">shape</span>
<span class="n">next_step_shape</span><span class="p">[</span><span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
<span class="n">next_step_buffer</span><span class="p">[</span><span class="s1">&#39;encoder_output&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">next_step_buffer</span><span class="p">[</span>
<span class="s1">&#39;encoder_input_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">encoder_input_lengths</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">next_step_buffer</span><span class="p">[</span><span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
<span class="n">next_step_buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;kv_cache_block_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">kv_cache_block_pointers</span><span class="p">[</span>
<span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">kv_cache_block_pointers</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
<span class="n">shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="o">*</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:]]</span>
<span class="n">next_step_shape</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;kv_cache_block_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">shape</span>
<span class="k">if</span> <span class="n">prompt_embedding_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">next_step_buffer</span><span class="p">[</span>
<span class="s1">&#39;prompt_embedding_table&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">prompt_embedding_table</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">next_step_shape</span><span class="p">[</span>
<span class="s1">&#39;prompt_embedding_table&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">prompt_embedding_table</span><span class="o">.</span><span class="n">shape</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">gen_tasks</span> <span class="o">=</span> <span class="n">tasks</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">gen_tasks</span> <span class="o">=</span> <span class="n">tasks</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">next_step_buffer</span><span class="p">[</span><span class="s1">&#39;tasks&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">gen_tasks</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">next_step_shape</span><span class="p">[</span><span class="s1">&#39;tasks&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">gen_tasks</span><span class="o">.</span><span class="n">shape</span>
<span class="n">next_step_buffer</span><span class="p">[</span>
<span class="s1">&#39;prompt_vocab_size&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">prompt_vocab_size</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">next_step_shape</span><span class="p">[</span><span class="s1">&#39;prompt_vocab_size&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">prompt_vocab_size</span><span class="o">.</span><span class="n">shape</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">next_step_buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;1_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="p">})</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">next_step_buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;1_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="p">})</span>
<span class="n">next_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span>
<span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
<span class="n">next_step_shape</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">next_shape</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">key_value_cache</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="n">cache_shape</span> <span class="o">=</span> <span class="n">key_value_cache</span><span class="o">.</span><span class="n">shape</span>
<span class="n">next_step_buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">key_value_cache</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">key_value_cache</span><span class="p">,</span>
<span class="p">})</span>
<span class="n">next_step_shape</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">cache_shape</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_cache_shape</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
<span class="n">cross_cache_buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="n">next_step_buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">cross_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">cross_cache_buffer</span><span class="p">,</span>
<span class="p">})</span>
<span class="n">next_step_shape</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">cross_cache_shape</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="c1"># generation requests</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="c1"># The previous [past_kv_length, is_context] format has been deprecated; only past_kv_length should be given here.</span>
<span class="c1"># Note we should use max_context_length here to align to max -- but isn&#39;t this done in attn plugin&#39;s max_element() already?</span>
<span class="n">host_past_key_value_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
<span class="p">[</span><span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
<span class="n">next_step_shape</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="s1">&#39;sequence_length&#39;</span><span class="p">:</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="p">),</span>
<span class="s1">&#39;host_past_key_value_lengths&#39;</span><span class="p">:</span>
<span class="n">host_past_key_value_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="s1">&#39;host_request_types&#39;</span><span class="p">:</span>
<span class="n">host_request_types</span><span class="o">.</span><span class="n">shape</span>
<span class="p">})</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="n">next_step_shape</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;host_max_kv_cache_length_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span>
<span class="p">})</span>
<span class="n">next_step_buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;host_max_kv_cache_length_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_kv_cache_lengths</span><span class="p">[</span><span class="n">idx</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">],</span>
<span class="p">})</span>
<span class="n">next_step_buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="c1"># Sequence lengths are not actually used in the context phase.</span>
<span class="s1">&#39;sequence_length&#39;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span>
<span class="s1">&#39;host_past_key_value_lengths&#39;</span><span class="p">:</span> <span class="n">host_past_key_value_lengths</span><span class="p">,</span>
<span class="s1">&#39;host_request_types&#39;</span><span class="p">:</span> <span class="n">host_request_types</span><span class="p">,</span>
<span class="p">})</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">next_step_buffer</span><span class="p">[</span>
<span class="s1">&#39;host_context_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">next_step_shape</span><span class="p">[</span>
<span class="s1">&#39;host_context_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">shape</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">next_step_shape</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">:</span> <span class="n">attention_mask</span><span class="o">.</span><span class="n">shape</span><span class="p">})</span>
<span class="n">next_step_buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="s1">&#39;attention_mask&#39;</span><span class="p">:</span>
<span class="n">attention_mask</span><span class="o">.</span><span class="n">contiguous</span><span class="p">(),</span>
<span class="p">})</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_custom_all_reduce</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">next_step_shape</span><span class="p">[</span>
<span class="s1">&#39;all_reduce_workspace&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span><span class="o">.</span><span class="n">shape</span>
<span class="n">next_step_buffer</span><span class="p">[</span><span class="s1">&#39;all_reduce_workspace&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span><span class="p">:</span>
<span class="n">next_step_shape</span><span class="p">[</span><span class="s1">&#39;lora_ranks&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;lora_ranks&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
<span class="n">next_step_buffer</span><span class="p">[</span><span class="s1">&#39;lora_ranks&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;lora_ranks&#39;</span><span class="p">]</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
<span class="n">next_step_shape</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
<span class="n">next_step_buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="k">return</span> <span class="n">next_step_shape</span><span class="p">,</span> <span class="n">next_step_buffer</span>
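The past/present swap in the non-plugin branch above is a double-buffering (ping-pong) scheme: the tensor written as `present_key_value` in one step is read as `past_key_value` in the next, so no copy is needed between steps. A minimal pure-Python sketch, with illustrative names (`buf_a`/`buf_b` stand in for `present_key_value_{idx}` and `1_present_key_value_{idx}`; this is not the TensorRT-LLM API):

```python
def bind_kv_buffers(step: int, a: str, b: str) -> dict:
    """Mirror the `if step % 2` branch: on odd steps read from `b` and
    write to `a`; on even steps the reverse."""
    if step % 2:
        return {"past_key_value": b, "present_key_value": a}
    return {"past_key_value": a, "present_key_value": b}


# Across consecutive steps, the previous step's output buffer becomes
# the next step's input buffer, so the KV cache is never copied.
s0 = bind_kv_buffers(0, "buf_a", "buf_b")
s1 = bind_kv_buffers(1, "buf_a", "buf_b")
assert s0["present_key_value"] == s1["past_key_value"]
```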
<span class="k">def</span> <span class="nf">_prepare_context_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;max_context_length&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">([</span>
<span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
<span class="p">]),</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">expand</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">}</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;input_ids&#39;</span><span class="p">)</span>
<span class="n">pad_id</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;pad_id&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">_prepare_attention_mask</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">pad_id</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">attention_mask</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">position_ids</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">attention_mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;attention_mask&#39;</span><span class="p">:</span> <span class="n">attention_mask</span><span class="p">,</span>
<span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span>
<span class="p">}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
<span class="n">ret</span><span class="p">[</span><span class="s1">&#39;position_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">position_ids</span>
<span class="k">return</span> <span class="n">ret</span>
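With `remove_input_padding`, `_prepare_context_inputs` packs all sequences into one row: position ids restart at 0 for each sequence, and `last_token_ids` becomes the running cumulative sum of the context lengths (the index one past each sequence's last token). A pure-Python sketch of that layout, with illustrative names that are not part of the TensorRT-LLM API:

```python
def packed_context_inputs(context_lengths):
    """Sketch of the packed (remove_input_padding) context layout:
    per-sequence position ids concatenated into one row, plus the
    cumulative-sum last_token_ids used to gather each sequence's
    final hidden state."""
    # Position ids restart at 0 for every sequence in the packed row.
    position_ids = [p for n in context_lengths for p in range(n)]
    # last_token_ids[i] = total tokens up to and including sequence i.
    last_token_ids = []
    total = 0
    for n in context_lengths:
        total += n
        last_token_ids.append(total)
    return position_ids, last_token_ids


pos, last = packed_context_inputs([3, 2])
assert pos == [0, 1, 2, 0, 1]
assert last == [3, 5]
```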
<span class="k">def</span> <span class="nf">_prepare_generation_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">step</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;step&#39;</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">context_lengths</span> <span class="o">+</span> <span class="n">step</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">}</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">)</span>
<span class="n">num_beams</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;num_beams&#39;</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">attention_mask</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">,</span> <span class="mi">1</span><span class="p">))),</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">attention_mask</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">position_ids</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">attention_mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">,</span>
<span class="s1">&#39;attention_mask&#39;</span><span class="p">:</span> <span class="n">attention_mask</span><span class="p">,</span>
<span class="p">}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
<span class="n">ret</span><span class="p">[</span><span class="s1">&#39;position_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">position_ids</span>
<span class="k">return</span> <span class="n">ret</span>
<div class="viewcode-block" id="GenerationSession.pp_communicate_new_tokens">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.pp_communicate_new_tokens">[docs]</a>
<span class="k">def</span> <span class="nf">pp_communicate_new_tokens</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">should_stop</span><span class="p">,</span> <span class="n">cache_indir</span><span class="p">,</span>
<span class="n">sequence_length</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">for</span> <span class="n">pg</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">:</span>
<span class="k">if</span> <span class="n">pg</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">should_stop</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">cache_indir</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">should_stop</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">cache_indir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">return</span> <span class="n">should_stop</span></div>
<div class="viewcode-block" id="GenerationSession.pp_communicate_final_output_ids">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.pp_communicate_final_output_ids">[docs]</a>
<span class="k">def</span> <span class="nf">pp_communicate_final_output_ids</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">final_output_ids</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">return</span> <span class="n">final_output_ids</span></div>
<div class="viewcode-block" id="GenerationSession.finalize_decoder">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.finalize_decoder">[docs]</a>
<span class="k">def</span> <span class="nf">finalize_decoder</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">):</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="c1"># output shape of self.gather_tree: [batch_size, beam_width, output_len]</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_tree</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_tgt</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_sequence_lengths_tgt</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">length_penalty</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">)</span>
<span class="c1"># Communicate ranks in Pipeline Parallelism</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pp_communicate_final_output_ids</span><span class="p">(</span>
<span class="n">final_output_ids</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="k">return</span> <span class="n">final_output_ids</span></div>
<div class="viewcode-block" id="GenerationSession.handle_per_step">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.handle_per_step">[docs]</a>
<span class="k">def</span> <span class="nf">handle_per_step</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span> <span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">next_step_buffer</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span>
<span class="n">stop_words_list</span><span class="p">,</span> <span class="n">bad_words_list</span><span class="p">,</span> <span class="n">no_repeat_ngram_size</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_0</span>
<span class="n">this_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">this_tgt_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">next_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_1</span>
<span class="n">this_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">this_tgt_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">next_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">model_inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_context_inputs</span><span class="p">(</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="o">=</span><span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_context_length</span><span class="p">,</span>
<span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">pad_id</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span><span class="p">,</span>
<span class="n">eos_id</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;position_ids&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">kv_cache_block_pointers</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">get_pointer_arrays</span><span class="p">(</span>
<span class="mi">1</span><span class="p">)</span>
<span class="n">ctx_shape</span><span class="p">,</span> <span class="n">ctx_buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_context_shape_buffer</span><span class="p">(</span>
<span class="n">input_ids</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">position_ids</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">this_src_cache_indirection</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span> <span class="n">prompt_embedding_table</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">,</span> <span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">encoder_output</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">)</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">ctx_context</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_set_shape</span><span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">ctx_shape</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_set_buffer</span><span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">ctx_buffer</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span> <span class="o">=</span> <span class="n">ctx_buffer</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span><span class="p">:</span>
<span class="c1"># context mode, clean cuda graph instances</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">)]</span>
<span class="c1"># dynamic_decoder currently uses torch&#39;s current stream, so TRT enqueue must use the same stream here</span>
<span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">cuda_stream</span>
<span class="n">instance_idx</span> <span class="o">=</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span>
<span class="n">instance_idx</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># launch cuda graph</span>
<span class="n">CUASSERT</span><span class="p">(</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphLaunch</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">],</span> <span class="n">stream</span><span class="p">))</span>
<span class="n">ok</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">ok</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_run</span><span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">ok</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;Executing TRT engine failed!&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">context_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_all_token_logits</span><span class="p">:</span>
<span class="n">context_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="c1"># reshape self.buffer[&#39;logits&#39;] from [bs, max_context_length, vocab]</span>
<span class="c1"># to [1, bs * max_context_length, vocab]</span>
<span class="c1"># Note that the data are put in the buffer without padding although</span>
<span class="c1"># the allocated buffer has padding.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">index_select</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">index</span><span class="o">=</span><span class="n">last_token_ids</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">))</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">beam_width</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># these tiled tensors are returned by handle_per_step(), so they can be relayed to subsequent generation calls</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">attention_mask</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="n">context_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="n">host_context_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">)</span>
<span class="k">if</span> <span class="n">encoder_input_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
<span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="k">if</span> <span class="n">tasks</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">tasks</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">tasks</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="c1"># Tiling must be done before computing the context-phase logits</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="c1"># Note: this tiles both self attn cache and cross attn cache!</span>
<span class="c1"># both names contain &quot;present_key_value&quot;</span>
<span class="k">if</span> <span class="s2">&quot;present_key_value&quot;</span> <span class="ow">in</span> <span class="n">key</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="c1"># Initialize sequence_lengths (without padding) for the generation phase.</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># Set shape and address for the next step</span>
<span class="n">model_inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_generation_inputs</span><span class="p">(</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">,</span>
<span class="n">step</span><span class="o">=</span><span class="n">step</span><span class="p">,</span>
<span class="n">num_beams</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;position_ids&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">kv_cache_block_pointers</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">get_pointer_arrays</span><span class="p">(</span>
<span class="n">beam_width</span><span class="p">)</span>
<span class="n">next_context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_1</span> <span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_0</span>
<span class="n">next_step_shape</span><span class="p">,</span> <span class="n">next_step_buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_next_step_shape_buffer</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">position_ids</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">next_src_cache_indirection</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span> <span class="n">prompt_embedding_table</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">,</span> <span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">encoder_output</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_set_shape</span><span class="p">(</span><span class="n">next_context</span><span class="p">,</span> <span class="n">next_step_shape</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_set_buffer</span><span class="p">(</span><span class="n">next_context</span><span class="p">,</span> <span class="n">next_step_buffer</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span> <span class="o">=</span> <span class="n">next_step_buffer</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span><span class="p">:</span>
<span class="c1"># Capture the next step as a CUDA graph</span>
<span class="n">CUASSERT</span><span class="p">(</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaStreamBeginCapture</span><span class="p">(</span>
<span class="n">stream</span><span class="p">,</span> <span class="n">cudart</span><span class="o">.</span><span class="n">cudaStreamCaptureMode</span><span class="o">.</span>
<span class="n">cudaStreamCaptureModeGlobal</span><span class="p">))</span>
<span class="n">next_context</span><span class="o">.</span><span class="n">execute_async_v3</span><span class="p">(</span><span class="n">stream</span><span class="p">)</span>
<span class="n">next_graph</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaStreamEndCapture</span><span class="p">(</span><span class="n">stream</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">instance_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span>
<span class="n">instance_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">_update_cuda_graph_instance</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">],</span>
<span class="n">next_graph</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphInstantiate</span><span class="p">(</span><span class="n">next_graph</span><span class="p">,</span> <span class="mi">0</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># Pre-upload the CUDA graph to the stream so the first launch is fast</span>
<span class="n">CUASSERT</span><span class="p">(</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphUpload</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">],</span>
<span class="n">stream</span><span class="p">))</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span><span class="p">:</span>
<span class="c1"># optionally filter which tensors to save based on the output name</span>
<span class="n">tensors_to_save</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">tensors_to_save</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span>
<span class="k">if</span> <span class="nb">all</span><span class="p">(</span><span class="n">kk</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">k</span> <span class="k">for</span> <span class="n">kk</span> <span class="ow">in</span> <span class="n">tensors_to_save</span><span class="p">):</span>
<span class="k">continue</span>
<span class="n">t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="c1"># consolidate all but last dim</span>
<span class="c1"># convert tensor name to valid file name</span>
<span class="n">fname</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">c</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">k</span> <span class="k">if</span> <span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">isalnum</span><span class="p">()</span> <span class="ow">or</span> <span class="n">c</span> <span class="ow">in</span> <span class="s2">&quot;._-&quot;</span><span class="p">))</span>
<span class="n">np</span><span class="o">.</span><span class="n">savetxt</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">fname</span><span class="si">}</span><span class="s2">-step</span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">.txt&quot;</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
<span class="k">if</span> <span class="n">logits</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># [batch_size x beam_width, vocab_size_padded] -&gt; [batch_size, beam_width, vocab_size_padded]</span>
<span class="n">next_token_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
<span class="n">decode_step</span> <span class="o">=</span> <span class="n">step</span> <span class="o">+</span> <span class="n">max_context_length</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span>
<span class="n">next_token_logits</span><span class="p">,</span> <span class="n">decode_step</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span><span class="p">,</span> <span class="n">ite</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">embedding_bias_opt</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">,</span> <span class="n">stop_words_list</span><span class="p">,</span> <span class="n">bad_words_list</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">this_src_cache_indirection</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span><span class="p">,</span>
<span class="n">this_tgt_cache_indirection</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_tgt</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_sequence_lengths_tgt</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pp_communicate_new_tokens</span><span class="p">(</span>
<span class="n">should_stop</span><span class="p">,</span> <span class="n">this_tgt_cache_indirection</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">if</span> <span class="p">(</span><span class="n">step</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">()):</span>
<span class="c1"># Free all blocks in all sequences.</span>
<span class="c1"># With in-flight batching and a while loop, sequences are freed individually as they finish</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">True</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># Iterate to the next step in KV cache manager.</span>
<span class="c1"># Increase number of tokens for all unfinished sequences.</span>
<span class="c1"># And allocate new blocks if needed.</span>
<span class="c1"># Pass False for every sequence, since only the length criterion is used to stop here</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">False</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
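The two `kv_cache_manager.step(...)` calls above can be mimicked with a toy manager (purely illustrative; the real `KVCacheManager` allocates and frees GPU blocks rather than plain counters):

```python
class ToyKVCacheManager:
    """Stand-in that tracks per-sequence token counts instead of GPU blocks."""

    def __init__(self, context_lengths):
        self.lengths = list(context_lengths)
        self.active = [True] * len(context_lengths)

    def step(self, finished):
        # Grow every unfinished sequence by one token; a finished flag is
        # where the real manager returns that sequence's blocks to the pool.
        for i, done in enumerate(finished):
            if done:
                self.active[i] = False
            elif self.active[i]:
                self.lengths[i] += 1

mgr = ToyKVCacheManager([3, 5])
mgr.step([False, False])  # normal generation step: both sequences grow
mgr.step([True, True])    # stop condition hit: everything is released
```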
<span class="k">return</span> <span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_buffer</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">context_logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span></div>
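The generation loop above double-buffers its TensorRT execution contexts and CUDA-graph instances; the index arithmetic can be sketched in isolation (function and variable names here are illustrative, not the TensorRT-LLM API):

```python
def pick_slots(step):
    """Mirror the alternation in handle_per_step: the context prepared for
    step + 1 is context_1 on odd steps and context_0 on even ones
    (`step % 2`), while the CUDA-graph instance captured for step + 1 lives
    in slot `(step + 1) % 2`, so consecutive steps always use different
    slots."""
    next_context = 1 if step % 2 else 0
    instance_idx = (step + 1) % 2
    return next_context, instance_idx

# Consecutive steps flip both indices, so a resource is never reused
# while the previous step may still be using it.
slots = [pick_slots(s) for s in range(4)]
```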
<div class="viewcode-block" id="GenerationSession.decode_regular">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_regular">[docs]</a>
<span class="k">def</span> <span class="nf">decode_regular</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">stop_words_list</span><span class="p">,</span>
<span class="n">bad_words_list</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="p">,</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="n">kv_cache_block_pointers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">next_step_buffer</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">context_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">generation_logits</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">def</span> <span class="nf">get_outputs_dict</span><span class="p">(</span><span class="n">output_ids</span><span class="p">):</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;output_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">output_ids</span>
<span class="k">if</span> <span class="n">output_sequence_lengths</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span>
<span class="s1">&#39;sequence_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">])</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_all_token_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;context_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_logits</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;generation_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">generation_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
<span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_buffer</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">handle_per_step</span><span class="p">(</span>
<span class="n">cache_indirections</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">,</span> <span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">ite</span><span class="p">,</span> <span class="n">sequence_limit_lengths</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">,</span> <span class="n">next_step_buffer</span><span class="p">,</span> <span class="n">stop_words_list</span><span class="p">,</span>
<span class="n">bad_words_list</span><span class="p">,</span> <span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">encoder_output</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_all_token_logits</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">context_logits</span> <span class="o">=</span> <span class="n">logits</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">generation_logits</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">next_step_buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
<span class="k">if</span> <span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">return</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">final_output_ids</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_all_token_logits</span><span class="p">:</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;context_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_logits</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;generation_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">generation_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">return</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">final_output_ids</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">()</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_all_token_logits</span><span class="p">:</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;context_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_logits</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;generation_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">generation_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">None</span></div>
<div class="viewcode-block" id="GenerationSession.decode_stream">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_stream">[docs]</a>
<span class="k">def</span> <span class="nf">decode_stream</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">stop_words_list</span><span class="p">,</span>
<span class="n">bad_words_list</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="p">,</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="n">kv_cache_block_pointers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">next_step_buffer</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">context_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">def</span> <span class="nf">get_outputs_dict</span><span class="p">(</span><span class="n">output_ids</span><span class="p">):</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;output_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">output_ids</span>
<span class="k">if</span> <span class="n">output_sequence_lengths</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span>
<span class="s1">&#39;sequence_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">])</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_all_token_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;context_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
<span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_buffer</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">handle_per_step</span><span class="p">(</span>
<span class="n">cache_indirections</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">,</span> <span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">ite</span><span class="p">,</span> <span class="n">sequence_limit_lengths</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">,</span> <span class="n">next_step_buffer</span><span class="p">,</span> <span class="n">stop_words_list</span><span class="p">,</span>
<span class="n">bad_words_list</span><span class="p">,</span> <span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">encoder_output</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">context_logits</span> <span class="o">=</span> <span class="n">logits</span>
<span class="k">if</span> <span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">final_output_ids</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="k">return</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">final_output_ids</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="kc">None</span></div>
<div class="viewcode-block" id="GenerationSession.decode_batch">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_batch">[docs]</a>
<span class="k">def</span> <span class="nf">decode_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">input_ids</span><span class="p">,</span> <span class="n">context_lengths</span> <span class="o">=</span> <span class="n">_prepare_input_ids</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">sampling_config</span><span class="p">,</span>
<span class="n">streaming</span><span class="o">=</span><span class="n">streaming</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<span class="c1"># As dynamic_decoder uses torch&#39;s current stream, we must ensure it runs on the same stream that</span>
<span class="c1"># dynamic_decoder was set up with</span>
<div class="viewcode-block" id="GenerationSession.decode">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode">[docs]</a>
<span class="nd">@cuda_stream_guard</span>
<span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stop_words_list</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">bad_words_list</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="n">scfg</span> <span class="o">=</span> <span class="n">sampling_config</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">beam_width</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="n">host_context_lengths</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="k">assert</span> <span class="n">batch_size</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> \
<span class="s2">&quot;Given batch size is different from the one used in setup(), &quot;</span> \
<span class="s2">&quot;rerun the setup function with the new batch size to avoid buffer overflow.&quot;</span>
<span class="k">assert</span> <span class="n">max_context_length</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_context_length</span><span class="p">,</span> \
<span class="s2">&quot;Given input length is larger than the one used in setup(), &quot;</span> \
<span class="s2">&quot;rerun the setup function with the new max_context_length to avoid buffer overflow.&quot;</span>
<span class="k">assert</span> <span class="n">beam_width</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_width</span><span class="p">,</span> \
<span class="s2">&quot;Given beam width is different from the one used in setup(), &quot;</span> \
<span class="s2">&quot;rerun the setup function with the new beam width to avoid buffer overflow.&quot;</span>
<span class="n">ite</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># index of local batches, will always be 0 if pp_size = 1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">__setup_decoder</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;Buffer not allocated, please call setup first!&#39;</span><span class="p">)</span>
<span class="n">sequence_limit_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># Sequence_lengths for the dynamic decoder still has the input paddings.</span>
<span class="n">sequence_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">max_context_length</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">cache_indirections</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span><span class="p">,</span>
<span class="p">),</span>
<span class="mi">0</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span><span class="p">,</span>
<span class="p">),</span>
<span class="mi">0</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="p">]</span> <span class="c1"># ping-pong buffers</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">max_num_tokens</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">max_num_tokens</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">))</span>
<span class="c1"># Init KV cache block manager</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">max_blocks_per_seq</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span> <span class="o">/</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">)</span>
<span class="n">blocks</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span> <span class="o">*</span> <span class="n">max_blocks_per_seq</span>
<span class="n">memory_pools</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span> <span class="o">=</span> <span class="n">KVCacheManager</span><span class="p">(</span><span class="n">memory_pools</span><span class="p">,</span> <span class="n">blocks</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
<span class="n">max_blocks_per_seq</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_kv_cache_length</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">)</span>
<span class="c1"># Add sequences to the manager</span>
<span class="k">for</span> <span class="n">bi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">generation_sequence</span> <span class="o">=</span> <span class="n">GenerationSequence</span><span class="p">(</span><span class="n">seq_idx</span><span class="o">=</span><span class="n">bi</span><span class="p">,</span>
<span class="n">batch_idx</span><span class="o">=</span><span class="n">bi</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">add_sequence</span><span class="p">(</span><span class="n">generation_sequence</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">)</span>
<span class="c1"># start context phase</span>
<span class="k">if</span> <span class="n">streaming</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode_stream</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span> <span class="n">sequence_lengths</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">ite</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">,</span> <span class="n">stop_words_list</span><span class="p">,</span> <span class="n">bad_words_list</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">output_sequence_lengths</span><span class="p">,</span> <span class="n">return_dict</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode_regular</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span> <span class="n">sequence_lengths</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">ite</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">,</span> <span class="n">stop_words_list</span><span class="p">,</span> <span class="n">bad_words_list</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">output_sequence_lengths</span><span class="p">,</span> <span class="n">return_dict</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">)</span></div>
</div>
<div class="viewcode-block" id="ChatGLMGenerationSession">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.ChatGLMGenerationSession">[docs]</a>
<span class="k">class</span> <span class="nc">ChatGLMGenerationSession</span><span class="p">(</span><span class="n">GenerationSession</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">_prepare_context_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;max_context_length&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">input_lengths_acc</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">([</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">cuda</span><span class="p">(),</span> <span class="n">context_lengths</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span><span class="p">]:</span><span class="n">input_lengths_acc</span><span class="p">[</span>
<span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span>
<span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">position_ids</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">length</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">2</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:,</span> <span class="n">length</span><span class="p">:]</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;position_ids&#39;</span><span class="p">:</span> <span class="n">position_ids</span><span class="p">,</span>
<span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span>
<span class="p">}</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">inputs</span><span class="p">[</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">attention_mask</span>
<span class="k">return</span> <span class="n">inputs</span>
<span class="k">def</span> <span class="nf">_prepare_generation_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">step</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;step&#39;</span><span class="p">)</span>
<span class="n">num_beams</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;num_beams&#39;</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">step</span> <span class="o">+</span> <span class="mi">2</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">data</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">([[</span><span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="n">step</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]])</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;position_ids&#39;</span><span class="p">:</span> <span class="n">position_ids</span><span class="p">,</span>
<span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span>
<span class="p">}</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">inputs</span><span class="p">[</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">attention_mask</span>
<span class="k">return</span> <span class="n">inputs</span></div>
</pre></div>
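<p>The <code>ChatGLMGenerationSession</code> overrides above build ChatGLM's 2-D position ids: row 0 holds the absolute position (with the last context token forced to <code>length - 2</code>), and row 1 holds the block position (1 at the last context token, then <code>step + 2</code> during generation). The following is a minimal pure-Python sketch of the padded context branch and the per-step generation branch; the helper names are hypothetical and plain lists stand in for the CUDA tensors used in the real session.</p>

```python
def chatglm_context_position_ids(context_lengths, max_context_length):
    # Mirrors the padded (remove_input_padding=False) context branch:
    # row 0: positions 0..L-1, last context token forced to L-2, padding zeroed;
    # row 1: block position, 1 only at the last context token.
    batch = []
    for length in context_lengths:
        row0 = list(range(max_context_length))
        row1 = [0] * max_context_length
        row0[length - 1] = length - 2
        row1[length - 1] = 1
        for p in range(length, max_context_length):
            row0[p] = 0  # zero out padded positions in both rows
        batch.append([row0, row1])
    return batch


def chatglm_generation_position_ids(context_lengths, step, num_beams=1):
    # Mirrors the padded generation branch: one [[L-2], [step+2]] pair per
    # sequence (context_lengths is laid out per beam, so index by i*num_beams).
    batch_size = len(context_lengths) // num_beams
    return [[[context_lengths[i * num_beams] - 2], [step + 2]]
            for i in range(batch_size)]
```

<p>For a single sequence of context length 3 padded to 4, the context ids come out as <code>[[0, 1, 1, 0], [0, 0, 1, 0]]</code>, and generation step 0 for a length-5 context yields <code>[[3], [2]]</code>, matching the tensor writes in the methods above.</p>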
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>&#169; Copyright 2023, NVIDIA.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>