Source code for tensorrt_llm.runtime.generation
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="kn">import</span> <span class="nn">copy</span>
<span class="kn">import</span> <span class="nn">csv</span>
<span class="kn">import</span> <span class="nn">math</span>
<span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span><span class="p">,</span> <span class="n">field</span>
<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">reduce</span><span class="p">,</span> <span class="n">wraps</span>
<span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Iterable</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">,</span> <span class="n">Union</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">tensorrt</span> <span class="k">as</span> <span class="nn">trt</span>
<span class="c1"># isort: off</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">tensorrt</span> <span class="k">as</span> <span class="nn">trt</span>
<span class="c1"># isort: on</span>
<span class="kn">from</span> <span class="nn">cuda</span> <span class="kn">import</span> <span class="n">cudart</span>
<span class="kn">from</span> <span class="nn">tensorrt_llm.plugin.plugin</span> <span class="kn">import</span> <span class="n">CustomAllReduceHelper</span>
<span class="kn">from</span> <span class="nn">.._ipc_utils</span> <span class="kn">import</span> <span class="n">set_peer_access</span>
<span class="kn">from</span> <span class="nn">.._utils</span> <span class="kn">import</span> <span class="p">(</span><span class="n">pad_vocab_size</span><span class="p">,</span> <span class="n">preview_trt_version</span><span class="p">,</span> <span class="n">str_dtype_to_torch</span><span class="p">,</span>
<span class="n">torch_to_numpy</span><span class="p">,</span> <span class="n">trt_dtype_to_torch</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">..logger</span> <span class="kn">import</span> <span class="n">logger</span>
<span class="kn">from</span> <span class="nn">..lora_manager</span> <span class="kn">import</span> <span class="n">LoraManager</span>
<span class="kn">from</span> <span class="nn">..mapping</span> <span class="kn">import</span> <span class="n">Mapping</span>
<span class="kn">from</span> <span class="nn">..quantization</span> <span class="kn">import</span> <span class="n">QuantMode</span>
<span class="kn">from</span> <span class="nn">.kv_cache_manager</span> <span class="kn">import</span> <span class="n">GenerationSequence</span><span class="p">,</span> <span class="n">KVCacheManager</span><span class="p">,</span> <span class="n">KVCacheUpdater</span>
<span class="kn">from</span> <span class="nn">.session</span> <span class="kn">import</span> <span class="n">_scoped_stream</span>
<div class="viewcode-block" id="to_word_list_format">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.to_word_list_format">[docs]</a>
<span class="k">def</span> <span class="nf">to_word_list_format</span><span class="p">(</span><span class="n">word_dict</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span>
<span class="n">tokenizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">add_special_tokens</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> format of word_dict</span>
<span class="sd"> len(word_dict) should be same to batch_size</span>
<span class="sd"> word_dict[i] means the words for batch i</span>
<span class="sd"> len(word_dict[i]) must be 1, which means it only contains 1 string</span>
<span class="sd"> This string can contains several sentences and split by &quot;,&quot;.</span>
<span class="sd"> For example, if word_dict[2] = &quot; I am happy, I am sad&quot;, then this function will return</span>
<span class="sd"> the ids for two short sentences &quot; I am happy&quot; and &quot; I am sad&quot;.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">tokenizer</span> <span class="o">!=</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;need to set tokenizer&quot;</span>
<span class="n">flat_ids</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">offsets</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">word_dict_item</span> <span class="ow">in</span> <span class="n">word_dict</span><span class="p">:</span>
<span class="n">item_flat_ids</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">item_offsets</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">word_dict_item</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="nb">bytes</span><span class="p">):</span>
<span class="n">word_dict_item</span> <span class="o">=</span> <span class="p">[</span><span class="n">word_dict_item</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">decode</span><span class="p">()]</span>
<span class="n">words</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">csv</span><span class="o">.</span><span class="n">reader</span><span class="p">(</span><span class="n">word_dict_item</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">for</span> <span class="n">word</span> <span class="ow">in</span> <span class="n">words</span><span class="p">:</span>
<span class="n">ids</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">word</span><span class="p">,</span> <span class="n">add_special_tokens</span><span class="o">=</span><span class="n">add_special_tokens</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">item_flat_ids</span> <span class="o">+=</span> <span class="n">ids</span>
<span class="n">item_offsets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">))</span>
<span class="n">flat_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">item_flat_ids</span><span class="p">))</span>
<span class="n">offsets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">item_offsets</span><span class="p">)))</span>
<span class="n">pad_to</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">max</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">)</span> <span class="k">for</span> <span class="n">ids</span> <span class="ow">in</span> <span class="n">flat_ids</span><span class="p">))</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">offs</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">flat_ids</span><span class="p">,</span> <span class="n">offsets</span><span class="p">)):</span>
<span class="n">flat_ids</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">pad_to</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">)),</span> <span class="n">constant_values</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">offsets</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="n">offs</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">pad_to</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">offs</span><span class="p">)),</span> <span class="n">constant_values</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">flat_ids</span><span class="p">,</span> <span class="n">offsets</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;int32&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span></div>
<span class="k">def</span> <span class="nf">_prepare_input_ids</span><span class="p">(</span><span class="n">tensors</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span>
<span class="n">tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">t</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">]</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span><span class="n">tensors</span><span class="p">)</span>
<span class="n">row_lengths</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">]</span>
<span class="n">row_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">row_lengths</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">data</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">row_lengths</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">CUASSERT</span><span class="p">(</span><span class="n">cuda_ret</span><span class="p">):</span>
<span class="n">err</span> <span class="o">=</span> <span class="n">cuda_ret</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="n">err</span> <span class="o">!=</span> <span class="n">cudart</span><span class="o">.</span><span class="n">cudaError_t</span><span class="o">.</span><span class="n">cudaSuccess</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;CUDA ERROR: </span><span class="si">{</span><span class="n">err</span><span class="si">}</span><span class="s2">, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">cuda_ret</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">return</span> <span class="n">cuda_ret</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="k">def</span> <span class="nf">_update_cuda_graph_instance</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">graph</span><span class="p">):</span>
<span class="n">err</span> <span class="o">=</span> <span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphExecUpdate</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">graph</span><span class="p">)</span>
<span class="k">if</span> <span class="n">err</span> <span class="o">!=</span> <span class="n">cudart</span><span class="o">.</span><span class="n">cudaError_t</span><span class="o">.</span><span class="n">cudaSuccess</span><span class="p">:</span>
<span class="c1"># When updating cuda graph failed, destroy and instantiate one.</span>
<span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphExecDestroy</span><span class="p">(</span><span class="n">instance</span><span class="p">))</span>
<span class="n">instance</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphInstantiate</span><span class="p">(</span><span class="n">graph</span><span class="p">,</span> <span class="mi">0</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">return</span> <span class="n">instance</span>
<span class="k">def</span> <span class="nf">_prepare_attention_mask</span><span class="p">(</span><span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">pad_id</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="n">is_pad_id_in_inputs</span> <span class="o">=</span> <span class="p">(</span><span class="n">pad_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span><span class="n">pad_id</span> <span class="ow">in</span> <span class="n">input_ids</span><span class="p">)</span>
<span class="k">if</span> <span class="n">input_ids</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">is_pad_id_in_inputs</span><span class="p">:</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">ne</span><span class="p">(</span><span class="n">pad_id</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="c1"># for enc-dec models, pad_id could be the start token and should be always counted</span>
<span class="c1"># as valid token rather than padded token, so we force its mask to be 1.</span>
<span class="c1"># This doesn&#39;t impact the existing behavior</span>
<span class="n">mask</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">mask</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_tile_beam_width</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">new_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">num_beams</span>
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="n">tile_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">tile_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
<span class="k">return</span> <span class="n">new_tensor</span>
<span class="k">class</span> <span class="nc">_Runtime</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="n">runtime_rank</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">runtime</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">Runtime</span>
<span class="n">engine</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ICudaEngine</span>
<span class="n">ctx_context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span>
<span class="n">context_0</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span>
<span class="n">context_1</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span>
<span class="n">cuda_graph_instances</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphExec_t</span><span class="p">]</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">,</span> <span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">address</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">__prepare</span><span class="p">(</span><span class="n">mapping</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_serialize_engine</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">trt</span><span class="o">.</span><span class="n">IHostMemory</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">serialize</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">__create_and_setup_context</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">address</span><span class="p">,</span> <span class="n">profile_idx</span><span class="p">,</span>
<span class="n">stream</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">:</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">create_execution_context_without_device_memory</span><span class="p">()</span>
<span class="k">assert</span> <span class="n">context</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">context</span><span class="o">.</span><span class="n">device_memory</span> <span class="o">=</span> <span class="n">address</span>
<span class="n">context</span><span class="o">.</span><span class="n">set_optimization_profile_async</span><span class="p">(</span><span class="n">profile_idx</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="k">return</span> <span class="n">context</span>
<span class="k">def</span> <span class="nf">__prepare</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime_rank</span> <span class="o">=</span> <span class="n">mapping</span><span class="o">.</span><span class="n">rank</span>
<span class="n">local_rank</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime_rank</span> <span class="o">%</span> <span class="n">mapping</span><span class="o">.</span><span class="n">gpus_per_node</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="n">local_rank</span><span class="p">)</span>
<span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaSetDevice</span><span class="p">(</span><span class="n">local_rank</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Runtime</span><span class="p">(</span><span class="n">logger</span><span class="o">.</span><span class="n">trt_logger</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">engine</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">deserialize_cuda_engine</span><span class="p">(</span><span class="n">engine_buffer</span><span class="p">)</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="c1"># The device_memory_size stores the memory required by the largest profile</span>
<span class="n">address</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaMalloc</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">address</span> <span class="o">=</span> <span class="n">address</span>
<span class="c1"># cuda graph ping-pong instances</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_instances</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">)]</span>
<span class="k">with</span> <span class="n">_scoped_stream</span><span class="p">()</span> <span class="k">as</span> <span class="n">stream</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_optimization_profiles</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># At step = 0, context_1 is active</span>
<span class="c1"># At step = 1, context_0 is active</span>
<span class="c1"># At step = 2, context_1 is active</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="n">address</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="n">address</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">context_1</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_optimization_profiles</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
<span class="c1"># At step = 0, ctx_context is active</span>
<span class="c1"># At step = 1, context_0 is active</span>
<span class="c1"># At step = 2, context_1 is active</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="n">address</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="n">address</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="n">address</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">&quot;Maximum of up to two optimization profiles only&quot;</span>
<span class="k">def</span> <span class="nf">_set_shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
<span class="n">shape_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">):</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">shape_dict</span><span class="p">:</span>
<span class="c1"># shape and buffer can be set by calling _set_tensors API</span>
<span class="k">continue</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_mode</span><span class="p">(</span><span class="n">name</span><span class="p">)</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">TensorIOMode</span><span class="o">.</span><span class="n">INPUT</span><span class="p">:</span>
<span class="n">ok</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">set_input_shape</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">shape_dict</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;setting input tensor </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> with shape </span><span class="si">{</span><span class="n">shape_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="si">}</span><span class="s2"> and type </span><span class="si">{</span><span class="n">dtype</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">ok</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Couldn&#39;t assign </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> with shape </span><span class="si">{</span><span class="n">shape_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;engine supports [min, opt, max] = </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_profile_shape</span><span class="p">(</span><span class="n">name</span><span class="p">,</span><span class="w"> </span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">active_optimization_profile</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">_set_buffer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
<span class="n">buffer_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">):</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="k">if</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">buffer_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">buffer_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">shape</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">buffer_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">(</span>
<span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> is not contiguous()&quot;</span>
<span class="n">context</span><span class="o">.</span><span class="n">set_tensor_address</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">buffer_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">())</span>
<span class="k">def</span> <span class="nf">_set_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
<span class="n">tensors</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="s2">&quot;RuntimeTensor&quot;</span><span class="p">]):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">):</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="c1"># it&#39;s allowed to call set_tensors multi times with different tensors</span>
<span class="c1"># each time only set some of the engine tensors, so it is valid to skip the ones not in the current given tensors dict</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_mode</span><span class="p">(</span><span class="n">name</span><span class="p">)</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">TensorIOMode</span><span class="o">.</span><span class="n">OUTPUT</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">tensors</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span>
<span class="n">name</span><span class="p">,</span>
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">shape</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">tensors</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
<span class="c1"># output&#39;s shape is inference by TRT, no need to set the shape here</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_mode</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">name</span><span class="p">)</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">TensorIOMode</span><span class="o">.</span><span class="n">INPUT</span><span class="p">:</span>
<span class="n">context</span><span class="o">.</span><span class="n">set_input_shape</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">context</span><span class="o">.</span><span class="n">set_tensor_address</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_check_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">):</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="n">ptr</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">get_tensor_address</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="k">if</span> <span class="n">ptr</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Engine I/O tensor </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> is unbound&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_run</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
<span class="n">stream</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">if</span> <span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">cuda_stream</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">stream</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">):</span>
<span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span><span class="o">.</span><span class="n">cuda_stream</span>
<span class="n">ok</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">execute_async_v3</span><span class="p">(</span><span class="n">stream</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ok</span>
<span class="k">def</span> <span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">try</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">address</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaFree</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">)</span>
<span class="k">except</span> <span class="ne">TypeError</span><span class="p">:</span>
<span class="k">pass</span>
<div class="viewcode-block" id="ModelConfig">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.ModelConfig">[docs]</a>
<span class="nd">@dataclass</span>
<span class="k">class</span> <span class="nc">ModelConfig</span><span class="p">:</span>
<span class="n">max_batch_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">max_beam_width</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">vocab_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">num_kv_heads</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">gpt_attention_plugin</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">remove_input_padding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">model_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
<span class="n">paged_kv_cache</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">cross_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">head_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">has_position_embedding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">has_token_type_embedding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">tokens_per_block</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span>
<span class="n">max_prompt_embedding_table_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">quant_mode</span><span class="p">:</span> <span class="n">QuantMode</span> <span class="o">=</span> <span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">gather_context_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">gather_generation_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
<span class="n">use_custom_all_reduce</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">lora_plugin</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">lora_target_modules</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="nb">list</span><span class="p">)</span>
<span class="n">use_context_fmha_for_generation</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">trtllm_modules_to_hf_modules</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">skip_cross_qkv</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">num_medusa_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">max_medusa_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">mamba_d_state</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">mamba_d_conv</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">mamba_expand</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">paged_state</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">mamba_conv1d_plugin</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></div>
<div class="viewcode-block" id="SamplingConfig">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.SamplingConfig">[docs]</a>
<span class="nd">@dataclass</span>
<span class="k">class</span> <span class="nc">SamplingConfig</span><span class="p">:</span>
<span class="n">end_id</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">pad_id</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span>
<span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">max_attention_window_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">sink_token_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">stop_words_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">bad_words_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">temperature</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="n">top_k</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">top_p</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="n">top_p_decay</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="c1"># float</span>
<span class="n">top_p_min</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="c1"># float</span>
<span class="n">top_p_reset_ids</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="c1"># int</span>
<span class="n">length_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="n">early_stopping</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">repetition_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="n">min_length</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">presence_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="n">frequency_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="n">use_beam_hyps</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># None here means user didn&#39;t set it, and dynamicDecodeOp.cpp take optional value</span>
<span class="c1"># The real default value is set in dynamicDecodeOp.cpp when it&#39;s None</span>
<span class="n">beam_search_diversity_rate</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="n">random_seed</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">output_cum_log_probs</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">output_log_probs</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<div class="viewcode-block" id="SamplingConfig.update">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.SamplingConfig.update">[docs]</a>
<span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">unused_kwargs</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">):</span>
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">unused_kwargs</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
<span class="k">return</span> <span class="n">unused_kwargs</span></div>
</div>
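<span class="c1"># --- Illustrative sketch (not part of the original source) -------------------</span>
<span class="c1"># Demonstrates SamplingConfig.update(): recognized fields are set in place</span>
<span class="c1"># and anything unrecognized is handed back to the caller. The end_id/pad_id</span>
<span class="c1"># values are tokenizer-specific assumptions.</span>
<span class="k">def</span> <span class="nf">_example_sampling_config</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="s1">&#39;SamplingConfig&#39;</span><span class="p">:</span>
    <span class="n">cfg</span> <span class="o">=</span> <span class="n">SamplingConfig</span><span class="p">(</span><span class="n">end_id</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">pad_id</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">max_new_tokens</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span>
    <span class="n">unused</span> <span class="o">=</span> <span class="n">cfg</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">temperature</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="n">top_k</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">not_a_field</span><span class="o">=</span><span class="mi">123</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">unused</span> <span class="o">==</span> <span class="p">{</span><span class="s1">&#39;not_a_field&#39;</span><span class="p">:</span> <span class="mi">123</span><span class="p">}</span>
    <span class="k">return</span> <span class="n">cfg</span>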
<div class="viewcode-block" id="LogitsProcessor">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.LogitsProcessor">[docs]</a>
<span class="k">class</span> <span class="nc">LogitsProcessor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Base class for all logit processors that can be applied during generation.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="si">}</span><span class="s2"> is an abstract class. Only classes inheriting this class can be called.&quot;</span>
<span class="p">)</span></div>
<div class="viewcode-block" id="LogitsProcessorList">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.LogitsProcessorList">[docs]</a>
<span class="k">class</span> <span class="nc">LogitsProcessorList</span><span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="n">LogitsProcessor</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
<span class="k">for</span> <span class="n">processor</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">:</span>
<span class="n">scores</span> <span class="o">=</span> <span class="n">processor</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">scores</span><span class="p">)</span>
<span class="k">return</span> <span class="n">scores</span></div>
<div class="viewcode-block" id="StoppingCriteria">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.StoppingCriteria">[docs]</a>
<span class="k">class</span> <span class="nc">StoppingCriteria</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Base class for all stopping criteria that can be applied during generation.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;StoppingCriteria needs to be subclassed&quot;</span><span class="p">)</span></div>
<div class="viewcode-block" id="StoppingCriteriaList">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.StoppingCriteriaList">[docs]</a>
<span class="k">class</span> <span class="nc">StoppingCriteriaList</span><span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="n">StoppingCriteria</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">any</span><span class="p">(</span><span class="n">criteria</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">scores</span><span class="p">)</span> <span class="k">for</span> <span class="n">criteria</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">)</span></div>
<span class="k">class</span> <span class="nc">RuntimeTensor</span><span class="p">:</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
<span class="c1"># shape is the one sent to TRT, the actual torch tensor can be larger than the shape</span>
<span class="c1"># this is useful when allocating a big KV cache tensor at the beginning and incremental seq length dim of TRT engine&#39;s input tensor</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_torch_tensor</span> <span class="o">=</span> <span class="kc">None</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">from_torch</span><span class="p">(</span>
<span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">data</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">override_shape</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Iterable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s1">&#39;RuntimeTensor&#39;</span><span class="p">:</span>
<span class="k">assert</span> <span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">))</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">RuntimeTensor</span><span class="p">()</span>
<span class="n">t</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="n">name</span>
<span class="c1"># need to hold the torch tensor for memory life time</span>
<span class="n">t</span><span class="o">.</span><span class="n">_torch_tensor</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">torch_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">size</span><span class="p">())</span>
<span class="k">if</span> <span class="n">override_shape</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">t</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">override_shape</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">override_shape</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span>
<span class="n">override_shape</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">all</span><span class="p">([</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">override_shape</span>
<span class="p">]),</span> <span class="sa">f</span><span class="s2">&quot;Expect all dimensions &gt;=0, got </span><span class="si">{</span><span class="n">override_shape</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="k">def</span> <span class="nf">volume_func</span><span class="p">(</span><span class="n">dims</span><span class="p">):</span>
<span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">volume_func</span><span class="p">(</span><span class="n">override_shape</span><span class="p">)</span> <span class="o">&lt;=</span> <span class="n">volume_func</span><span class="p">(</span><span class="n">torch_shape</span><span class="p">),</span> \
<span class="sa">f</span><span class="s2">&quot;Override the shape to be larger than the underlying torch Tensor, got </span><span class="si">{</span><span class="n">override_shape</span><span class="si">}</span><span class="s2">, torch tensor shape </span><span class="si">{</span><span class="n">torch_shape</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">t</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">torch_shape</span>
<span class="k">return</span> <span class="n">t</span>
<span class="k">def</span> <span class="nf">to_torch</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_torch_tensor</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_shape</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">data</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_torch_tensor</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">()</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_name</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_torch_tensor</span><span class="o">.</span><span class="n">dtype</span>
<div class="viewcode-block" id="GenerationSession">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession">[docs]</a>
<span class="k">class</span> <span class="nc">GenerationSession</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="n">_model_config</span><span class="p">:</span> <span class="n">ModelConfig</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span>
<span class="n">runtime</span><span class="p">:</span> <span class="n">_Runtime</span>
<span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">buffer_allocated</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">debug_mode</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">quant_mode</span><span class="p">:</span> <span class="n">QuantMode</span>
<span class="n">cuda_graph_mode</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">dtype</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span>
<span class="n">debug_tensors_to_save</span><span class="p">:</span> <span class="kc">None</span>
<span class="n">num_medusa_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">medusa_topks</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">medusa_paths</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">medusa_tree_ids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">medusa_position_offsets</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">medusa_temperature</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span> <span class="n">ModelConfig</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span> <span class="o">=</span> <span class="n">model_config</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span> <span class="o">=</span> <span class="n">mapping</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span> <span class="o">=</span> <span class="n">_Runtime</span><span class="p">(</span><span class="n">engine_buffer</span><span class="p">,</span> <span class="n">mapping</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;cuda:</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">runtime_rank</span><span class="w"> </span><span class="o">%</span><span class="w"> </span><span class="n">mapping</span><span class="o">.</span><span class="n">gpus_per_node</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># dynamic_decoder currently use torch&#39;s current stream, so must let TRT enqueue use same stream here</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="o">=</span> <span class="n">debug_mode</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="o">=</span> <span class="n">debug_tensors_to_save</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span> <span class="o">=</span> <span class="n">cuda_graph_mode</span>
<span class="c1"># Optional inputs for dynamic decoder</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_decay</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_min</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_reset_ids</span> <span class="o">=</span> <span class="kc">None</span>
<span class="c1"># TODO: in tensorrt_llm/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp it&#39;s T, can be float or half?</span>
<span class="bp">self</span><span class="o">.</span><span class="n">embedding_bias_opt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="c1"># use one more block in paged kv cache.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">use_one_more_block</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span> <span class="o">=</span> <span class="n">pad_vocab_size</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;The paged KV cache in Python runtime is experimental. For performance and correctness, please, use C++ runtime.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">classes</span><span class="o">.</span><span class="n">trtllm</span><span class="o">.</span><span class="n">NcclCommunicatorOp</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;logits&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">]:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Logits dtype not supported by decoder. Falling back to float32. You may want to change the logits dtype to float16 in your model definition.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">classes</span><span class="o">.</span><span class="n">trtllm</span><span class="o">.</span><span class="n">DynamicDecodeOp</span><span class="p">(</span>
<span class="n">model_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">use_context_fmha_for_generation</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Context FMHA is used for generation. Use it only for testing&quot;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gather_tree</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">gather_tree</span>
<span class="n">expected_tensor_names</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;input_ids&#39;</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;hidden_states_input&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">model_config</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;hidden_states_output&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">has_position_embedding</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;position_ids&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">has_token_type_embedding</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;token_type_ids&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;cache_indirection&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;kv_cache_block_pointers&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;host_kv_cache_block_pointers&#39;</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;sequence_length&#39;</span><span class="p">,</span> <span class="s1">&#39;context_lengths&#39;</span><span class="p">,</span> <span class="s1">&#39;host_request_types&#39;</span><span class="p">,</span>
<span class="s1">&#39;host_past_key_value_lengths&#39;</span><span class="p">,</span> <span class="s1">&#39;host_sink_token_length&#39;</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;host_max_attention_window_sizes&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">expected_tensor_names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s1">&#39;host_context_lengths&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;attention_mask&#39;</span><span class="p">,</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;prompt_embedding_table&#39;</span><span class="p">,</span> <span class="s1">&#39;tasks&#39;</span><span class="p">,</span> <span class="s1">&#39;prompt_vocab_size&#39;</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;cross_attention_mask&#39;</span><span class="p">,</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;encoder_output&#39;</span><span class="p">,</span>
<span class="s1">&#39;encoder_input_lengths&#39;</span><span class="p">,</span>
<span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">,</span>
<span class="s1">&#39;cross_kv_cache_gen&#39;</span><span class="p">,</span>
<span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_qkv</span> <span class="o">=</span> <span class="n">model_config</span><span class="o">.</span><span class="n">skip_cross_qkv</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_qkv</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;cross_qkv_reuse&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">model_config</span><span class="o">.</span><span class="n">use_custom_all_reduce</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;all_reduce_workspace&#39;</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">=</span> <span class="n">model_config</span><span class="o">.</span><span class="n">lora_target_modules</span>
<span class="c1"># In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or all disabled at the same time.</span>
<span class="c1"># However, there are some cases that the lora modules only contain one or two of them, so we use zero tensor to fill the missing ones.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">any</span><span class="p">(</span><span class="n">x</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span>
<span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;attn_q&quot;</span><span class="p">,</span> <span class="s2">&quot;attn_k&quot;</span><span class="p">,</span> <span class="s2">&quot;attn_v&quot;</span><span class="p">]):</span>
<span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;attn_q&quot;</span><span class="p">,</span> <span class="s2">&quot;attn_k&quot;</span><span class="p">,</span> <span class="s2">&quot;attn_v&quot;</span><span class="p">]:</span>
<span class="k">if</span> <span class="n">lora_module</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">lora_module</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">any</span><span class="p">(</span><span class="n">x</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span>
<span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;cross_attn_q&quot;</span><span class="p">,</span> <span class="s2">&quot;cross_attn_k&quot;</span><span class="p">,</span> <span class="s2">&quot;cross_attn_v&quot;</span><span class="p">]):</span>
<span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">[</span>
<span class="s2">&quot;cross_attn_q&quot;</span><span class="p">,</span> <span class="s2">&quot;cross_attn_k&quot;</span><span class="p">,</span> <span class="s2">&quot;cross_attn_v&quot;</span>
<span class="p">]:</span>
<span class="k">if</span> <span class="n">lora_module</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">lora_module</span><span class="p">)</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">lora_plugin</span><span class="p">:</span>
<span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">+</span>
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="p">):</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_ranks_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_weights_pointers_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;host_encoder_input_lengths&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">num_medusa_heads</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;medusa_position_offsets&#39;</span><span class="p">,</span> <span class="s1">&#39;medusa_packed_mask&#39;</span><span class="p">,</span> <span class="s1">&#39;medusa_logits&#39;</span>
<span class="p">]</span>
<span class="n">found_tensor_names</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="ow">and</span> <span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">set</span><span class="p">(</span>
<span class="n">found_tensor_names</span><span class="p">):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;The following expected tensors are not found: </span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">)</span><span class="o">.</span><span class="n">difference</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">))</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Those tensors in engine are not expected: </span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">)</span><span class="o">.</span><span class="n">difference</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">))</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Expected tensor names: </span><span class="si">{</span><span class="n">expected_tensor_names</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Found tensor names: </span><span class="si">{</span><span class="n">found_tensor_names</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;Tensor names in engine are not the same as expected, to use this GenerationSession, &quot;</span>
<span class="s2">&quot;you need to use PretrainedModel.prepare_inputs to create TRT Network inputs.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
<span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">)</span> <span class="o">-</span> <span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">))</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">vocab_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">vocab_size</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">num_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> \
<span class="sa">f</span><span class="s2">&quot;num_layers </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_layers</span><span class="si">}</span><span class="s2"> must be a multiple of pipeline parallelism size </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">first_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_rank</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">last_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">num_heads</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_heads</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">hidden_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">hidden_size</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">use_gpt_attention_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">use_mamba_conv1d_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">mamba_conv1d_plugin</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">paged_kv_cache</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">paged_kv_cache</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">tokens_per_block</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">tokens_per_block</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">remove_input_padding</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">num_heads_kv</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">head_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">head_size</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">head_size</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">max_prompt_embedding_table_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">quant_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">quant_mode</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">gather_context_logits</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gather_context_logits</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">gather_generation_logits</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gather_generation_logits</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="n">str_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">use_custom_all_reduce</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">use_custom_all_reduce</span>
<div class="viewcode-block" id="GenerationSession.cuda_stream_guard">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.cuda_stream_guard">[docs]</a>
<span class="k">def</span> <span class="nf">cuda_stream_guard</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Sync external stream and set current stream to the one bound to the session. Reset on exit.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@wraps</span><span class="p">(</span><span class="n">func</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">wrapper</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">external_stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span>
<span class="k">if</span> <span class="n">external_stream</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">:</span>
<span class="n">external_stream</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">)</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">func</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">if</span> <span class="n">external_stream</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="n">external_stream</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ret</span>
<span class="k">return</span> <span class="n">wrapper</span></div>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">cross_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">cross_attention</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">has_position_embedding</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">has_position_embedding</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">has_token_type_embedding</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">has_token_type_embedding</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">use_lora_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">lora_plugin</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">use_context_fmha_for_generation</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">use_context_fmha_for_generation</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">is_medusa_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span> <span class="o">&gt;</span> <span class="mi">0</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">max_medusa_tokens</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">max_medusa_tokens</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">num_medusa_heads</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_medusa_heads</span>
<span class="k">def</span> <span class="nf">_capture_cuda_graph_and_instantiate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">,</span> <span class="n">stream</span><span class="p">,</span> <span class="n">step</span><span class="p">):</span>
<span class="n">instance_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span>
<span class="c1"># capture cuda graph</span>
<span class="n">CUASSERT</span><span class="p">(</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaStreamBeginCapture</span><span class="p">(</span>
<span class="n">stream</span><span class="p">,</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaStreamCaptureMode</span><span class="o">.</span><span class="n">cudaStreamCaptureModeGlobal</span><span class="p">))</span>
<span class="n">context</span><span class="o">.</span><span class="n">execute_async_v3</span><span class="p">(</span><span class="n">stream</span><span class="p">)</span>
<span class="n">next_graph</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaStreamEndCapture</span><span class="p">(</span><span class="n">stream</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span>
<span class="n">instance_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">_update_cuda_graph_instance</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">],</span> <span class="n">next_graph</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphInstantiate</span><span class="p">(</span><span class="n">next_graph</span><span class="p">,</span> <span class="mi">0</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># Pre-upload cuda graph to stream</span>
<span class="n">CUASSERT</span><span class="p">(</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphUpload</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">],</span> <span class="n">stream</span><span class="p">))</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">paged_state</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">paged_state</span>
<span class="k">def</span> <span class="nf">__setup_decoder</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;Allocate buffers and setup the post-processing decoder kernel</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">scfg</span> <span class="o">=</span> <span class="n">sampling_config</span> <span class="c1"># just to make a shorter name, no other meaning</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.top_k.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.top_k.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.top_p.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.top_p.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.temperature.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.temperature.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">temperature</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">temperature</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.repetition_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.repetition_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">==</span> <span class="mf">1.0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.length_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.length_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">length_penalty</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.early_stopping.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.early_stopping.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_early_stopping</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_early_stopping</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">early_stopping</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">host_early_stopping</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.presence_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.presence_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">==</span> <span class="mf">0.0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.frequency_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.frequency_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">==</span> <span class="mf">0.0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.min_length.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.min_length.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min_length</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.beam_search_diversity_rate.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.beam_search_diversity_rate.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.random_seed.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int64&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.random_seed.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span><span class="o">.</span><span class="n">setup</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_p</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">host_early_stopping</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_decay</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_p_min</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_p_reset_ids</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;end_id cannot be none&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;pad_id cannot be none&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">max</span><span class="p">()</span>
<span class="c1"># setup output ids buffer</span>
<span class="k">if</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># input_ids only have one dimension, which means remove_padding is enabled</span>
<span class="n">split_ids_list</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">input_ids</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">host_context_lengths</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
<span class="n">padded_input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">to_padded_tensor</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">nested_tensor</span><span class="p">(</span><span class="n">split_ids_list</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">padded_input_ids</span> <span class="o">=</span> <span class="n">input_ids</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">tiled_input_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">padded_input_ids</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">)</span>
<span class="n">tiled_input_ids</span> <span class="o">=</span> <span class="n">tiled_input_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">)</span>
<span class="n">tiled_input_ids</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># TODO: delete?</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">(</span><span class="n">tiled_input_ids</span><span class="p">,</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">-</span> <span class="n">max_context_length</span><span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)),</span>
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">(</span><span class="n">padded_input_ids</span><span class="p">,</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">-</span> <span class="n">max_context_length</span><span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)),</span>
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># Note: we still allocate max_seq_length size of parent ids (not max_attention_window_size).</span>
<span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_temperature</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_logits</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;logits&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_cum_log_probs</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
<span class="o">-</span><span class="mf">1e20</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_log_probs</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs_tiled</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs_tiled</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">finished</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_tgt</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
<span class="n">size</span><span class="o">=</span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">],</span>
<span class="n">fill_value</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_sequence_lengths_tgt</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_tgt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_sequence_lengths_tgt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cross_qkv_reuse</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">def</span> <span class="nf">_tensor_dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="c1"># return torch dtype given tensor name for convenience</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
<span class="k">return</span> <span class="n">dtype</span>
<span class="k">def</span> <span class="nf">_init_medusa</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">medusa_choices</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]):</span>
<span class="kn">from</span> <span class="nn">tensorrt_llm.runtime.medusa_utils</span> <span class="kn">import</span> <span class="p">(</span><span class="n">_medusa_setup</span><span class="p">,</span>
<span class="n">expand_choices_if_needed</span><span class="p">)</span>
<span class="n">medusa_choices</span> <span class="o">=</span> <span class="n">expand_choices_if_needed</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">)</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">&lt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_medusa_tokens</span>
<span class="n">medusa_info</span> <span class="o">=</span> <span class="n">_medusa_setup</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_topks</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_topks</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_mask</span><span class="p">[</span><span class="mi">1</span><span class="p">:,</span> <span class="mi">1</span><span class="p">:]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">bool</span>
<span class="p">)</span> <span class="c1"># convert to bool, original mask includes true token as well</span>
<span class="c1"># Expand medusa position offsets to number of batch size in order to be compatible with the new Medusa.</span>
<span class="n">target_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_packed_mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">target_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_packed_mask</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_packed_mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span>
<span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">target_shape</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_paths</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_tree_ids</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_tree_ids</span>
<span class="c1"># Expand medusa position offsets to number of batch size in order to be compatible with the new Medusa.</span>
<span class="n">target_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
<span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_position_offsets</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">target_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_position_offsets</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_position_offsets</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span>
<span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">target_shape</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">medusa_fp_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">medusa_fp_mask</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span><span class="p">)]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span> <span class="o">=</span> <span class="n">medusa_fp_mask</span>
<span class="k">return</span>
<div class="viewcode-block" id="GenerationSession.setup">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.setup">[docs]</a>
<span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">max_attention_window_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">sink_token_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_max_input_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_manager</span><span class="p">:</span> <span class="n">LoraManager</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_uids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">medusa_choices</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="c1"># Store these params related to buffer size to check against</span>
<span class="c1"># the input shape with the params given in decode()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_context_length</span> <span class="o">=</span> <span class="n">max_context_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">=</span> <span class="n">max_new_tokens</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">=</span> <span class="n">max_context_length</span> <span class="o">+</span> <span class="n">max_new_tokens</span>
<span class="k">if</span> <span class="n">medusa_choices</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">max_medusa_tokens</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_width</span> <span class="o">=</span> <span class="n">beam_width</span>
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span> <span class="o">=</span> <span class="n">encoder_max_input_length</span>
<span class="k">if</span> <span class="n">max_attention_window_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="s2">&quot;The max_attention_window_size is not set, we will use max_seq_length by default.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="k">if</span> <span class="n">max_attention_window_size</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;The value of max_attention_window_size should ideally not exceed max_seq_length. &quot;</span>
<span class="s2">&quot;Therefore, it has been adjusted to match the value of max_seq_length.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;The value of max_attention_window_size should ideally not exceed max_seq_length. &quot;</span>
<span class="s2">&quot;Therefore, it has been adjusted to match the value of max_seq_length.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
<span class="k">if</span> <span class="n">max_attention_window_size</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="s2">&quot;max_attention_window_size tensor&#39;s size is not equal to num_layers! &quot;</span>
<span class="s2">&quot;Note that num_layers = num_total_layers // pipeline_parallelism_size.&quot;</span>
<span class="p">)</span>
<span class="k">assert</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span>
<span class="n">max_attention_window_size</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">&quot;invalid max_attention_window_size!&quot;</span>
<span class="k">if</span> <span class="n">sink_token_length</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sink_token_length</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">=</span> <span class="n">sink_token_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
<span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">&quot;invalid sink_token_length!&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">use_one_more_block</span> <span class="o">=</span> <span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="n">beam_width</span> <span class="o">&gt;</span> <span class="mi">1</span>
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span> <span class="o">=</span> <span class="n">lora_manager</span>
<span class="k">if</span> <span class="n">medusa_choices</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_init_medusa</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;logits&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">medusa_logits_shape</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">medusa_logits_shape</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">*</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;medusa_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">medusa_logits_shape</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;medusa_logits&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;logits&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="c1"># use shape info to pass max length info in remove padding mode</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">(</span><span class="n">encoder_max_input_length</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">bubble_len</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">bubble_len</span> <span class="o">+=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span> <span class="o">-</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">)</span>
<span class="n">blocks</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">+</span> <span class="n">bubble_len</span><span class="p">)</span> <span class="o">/</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_one_more_block</span><span class="p">:</span>
<span class="n">blocks</span> <span class="o">+=</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span>
<span class="n">cache_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">blocks</span><span class="p">,</span>
<span class="mi">2</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">cache_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="mi">2</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_cache_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="mi">2</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_kv_cache_quant</span><span class="p">():</span>
<span class="c1"># Since torch does not support fp8 now, using int8 here.</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">int8</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">cache_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">cross_cache_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># without plugin, we need two set of kv cache buffers,</span>
<span class="c1"># one for inputs, and the other for outputs.</span>
<span class="c1"># They will take turns to act as input and output buffers.</span>
<span class="c1"># Not applicable to cross KV buffers as it&#39;s constant</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="n">trt_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">trt_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span><span class="p">:</span>
<span class="c1"># PyTorch doesn&#39;t support fp8 datatype, use int8 instead of it because int8 datatype size is same with fp8.</span>
<span class="c1"># TODO: Remove this section when PyTorch support fp8 datatype</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">int8</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;1_present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">cache_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_custom_all_reduce</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">set_peer_access</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ipc_buffers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span> <span class="o">=</span> <span class="n">CustomAllReduceHelper</span><span class="o">.</span><span class="n">allocate_workspace</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="p">,</span>
<span class="n">CustomAllReduceHelper</span><span class="o">.</span><span class="n">max_workspace_size_auto</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">lora_uids</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">lora_weights_pointers_list</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
<span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">+</span>
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="p">):</span>
<span class="n">lora_ranks_</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">lora_ptrs_</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">batch_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">lora_uid</span> <span class="o">=</span> <span class="n">lora_uids</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">]</span>
<span class="k">if</span> <span class="n">lora_uid</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">lora_uid</span> <span class="o">!=</span> <span class="s2">&quot;-1&quot;</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span><span class="o">.</span><span class="n">uid_to_low_ranks</span><span class="p">(</span>
<span class="n">lora_uid</span><span class="p">)[</span><span class="n">layer_idx</span><span class="p">][</span><span class="n">lora_module</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">lora_ranks_</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span><span class="o">.</span><span class="n">uid_to_low_ranks</span><span class="p">(</span><span class="n">lora_uid</span><span class="p">)</span>
<span class="p">[</span><span class="n">layer_idx</span><span class="p">][</span><span class="n">lora_module</span><span class="p">])</span>
<span class="n">lora_ptrs_</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span><span class="o">.</span><span class="n">lora_weights_pointers_list</span><span class="p">[</span>
<span class="n">layer_idx</span><span class="p">][</span><span class="n">lora_uid</span><span class="p">][</span><span class="n">lora_module</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">lora_ranks_</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">lora_ptrs_</span><span class="o">.</span><span class="n">append</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_ranks_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">(</span><span class="n">lora_ranks_</span><span class="p">)</span>
<span class="p">})</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">lora_ptrs_</span><span class="p">)</span>
<span class="p">})</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;medusa_packed_mask&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_packed_mask</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="s1">&#39;medusa_position_offsets&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_position_offsets</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span></div>
<span class="k">def</span> <span class="nf">_get_context_shape_buffer</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">position_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">host_kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">hidden_states_input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">RuntimeTensor</span><span class="p">]:</span>
<span class="n">tensors</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">def</span> <span class="nf">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">add_tensor</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="n">name</span><span class="p">:</span> <span class="n">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">)})</span>
<span class="k">def</span> <span class="nf">add_tensor_with_shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
<span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">override_shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)})</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="s1">&#39;context_lengths&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cache_indirection</span><span class="p">,</span> <span class="s1">&#39;cache_indirection&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="s1">&#39;position_ids&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="c1"># in context phase, need to generate cross kv cache, set to True</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
<span class="s1">&#39;cross_kv_cache_gen&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_qkv</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_qkv_reuse</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># see Attention&#39;s self.qkv output dim</span>
<span class="n">cross_qkv_out_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span> <span class="o">+</span> <span class="p">(</span>
<span class="mi">2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
<span class="n">cross_qkv_shape</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span>
<span class="n">cross_qkv_out_dim</span><span class="p">,</span> <span class="p">)</span>
<span class="n">cross_qkv_reuse</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">cross_qkv_shape</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">encoder_output</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">encoder_output</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cross_qkv_reuse</span> <span class="o">=</span> <span class="n">cross_qkv_reuse</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cross_qkv_reuse</span><span class="p">,</span> <span class="s1">&#39;cross_qkv_reuse&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_output</span><span class="p">,</span> <span class="s1">&#39;encoder_output&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="s1">&#39;encoder_input_lengths&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">],</span>
<span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_mask</span><span class="p">,</span> <span class="s1">&#39;cross_attention_mask&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
<span class="k">if</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">hidden_states_input</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span>
<span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">hidden_size</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">hidden_states_input</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span>
<span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">hidden_size</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span> <span class="s1">&#39;logits&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;medusa_logits&#39;</span><span class="p">],</span> <span class="s1">&#39;medusa_logits&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="s1">&#39;last_token_ids&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">&#39;hidden_states_output&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="s1">&#39;input_ids&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">&#39;hidden_states_input&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">prompt_embedding_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="s1">&#39;prompt_embedding_table&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">tasks_generation</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">([</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">context_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">()],</span>
<span class="n">tasks</span><span class="p">[</span><span class="n">b</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="p">])</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">tasks_generation</span> <span class="o">=</span> <span class="n">tasks</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">tasks_generation</span><span class="p">,</span> <span class="s1">&#39;tasks&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="s1">&#39;prompt_vocab_size&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">buffer</span> <span class="o">=</span> <span class="n">kv_cache_block_pointers</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">kv_cache_block_pointers</span><span class="o">.</span><span class="n">shape</span>
<span class="n">shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="o">*</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">:]]</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">buffer</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;kv_cache_block_pointers&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_kv_cache_block_pointers</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;host_kv_cache_block_pointers&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">kv_cache_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
<span class="c1"># for empty tensor, TRT does not really use the tensor data, so any dtype is fine</span>
<span class="n">kv_cache_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">kv_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span>
<span class="n">kv_cache_shape</span><span class="p">)</span>
<span class="n">present</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">present</span><span class="p">],</span> <span class="n">present</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_kv_cache_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
<span class="c1"># for empty tensor, TRT does not really use the tensor data, so any dtype is fine</span>
<span class="n">cross_kv_cache_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">cross_kv_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span>
<span class="n">cross_kv_cache_shape</span><span class="p">)</span>
<span class="n">cross_present</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">cross_present</span><span class="p">],</span> <span class="n">cross_present</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">key_value_cache</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="c1"># when plugin is used, past_ket_value tensor does not need to be empty tensor</span>
<span class="c1"># because plugin does not care, and does not use this shape.</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_cache_buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="c1"># context request</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span>
<span class="s1">&#39;sequence_length&#39;</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">))</span>
<span class="c1"># field 0: past_key_value_length, field 1: is_context (deprecated). changed to [0], otherwise affects batch padded input mode</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span>
<span class="s1">&#39;host_past_key_value_lengths&#39;</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">))</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span><span class="p">,</span>
<span class="s1">&#39;host_sink_token_length&#39;</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">))</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">&#39;host_request_types&#39;</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;host_max_attention_window_sizes&#39;</span><span class="p">,</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span> <span class="s1">&#39;host_context_lengths&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">attention_mask</span><span class="p">,</span> <span class="s1">&#39;attention_mask&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_custom_all_reduce</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span><span class="p">,</span> <span class="s1">&#39;all_reduce_workspace&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span><span class="p">:</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
<span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">+</span>
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="p">):</span>
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
<span class="n">lora_ranks</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_ranks_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_ranks</span><span class="p">],</span> <span class="n">lora_ranks</span><span class="p">)</span>
<span class="n">lora_weights</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_weights</span><span class="p">],</span> <span class="n">lora_weights</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cpu&#39;</span><span class="p">),</span>
<span class="s1">&#39;host_encoder_input_lengths&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="c1"># Medusa mask and position offsets are fixed for the whole session.</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;medusa_packed_mask&#39;</span><span class="p">],</span> <span class="s1">&#39;medusa_packed_mask&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;medusa_position_offsets&#39;</span><span class="p">],</span>
<span class="s1">&#39;medusa_position_offsets&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tensors</span>
<span class="k">def</span> <span class="nf">_get_next_step_shape_buffer</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">position_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">host_kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">hidden_states_input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="n">tensors</span> <span class="o">=</span> <span class="p">{}</span> <span class="c1"># Dict[str, RuntimeTensor]</span>
<span class="k">def</span> <span class="nf">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">add_tensor</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="n">name</span><span class="p">:</span> <span class="n">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">)})</span>
<span class="k">def</span> <span class="nf">add_tensor_with_shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
<span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">override_shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)})</span>
<span class="n">context_lengths_local</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="n">host_context_lengths_local</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_context_fmha_for_generation</span><span class="p">:</span>
<span class="n">context_lengths_local</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">host_context_lengths_local</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">context_lengths_local</span><span class="p">,</span> <span class="s1">&#39;context_lengths&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cache_indirection</span><span class="p">,</span> <span class="s1">&#39;cache_indirection&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
<span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">hidden_size</span><span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="k">else</span> <span class="p">(</span>
<span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
<span class="n">hidden_states_input</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span><span class="o">*</span><span class="n">shape</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span> <span class="s1">&#39;logits&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;medusa_logits&#39;</span><span class="p">],</span> <span class="s1">&#39;medusa_logits&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="s1">&#39;last_token_ids&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">&#39;hidden_states_output&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="n">input_ids_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
<span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="k">else</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="p">,</span> <span class="s1">&#39;input_ids&#39;</span><span class="p">,</span>
<span class="n">input_ids_shape</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="s1">&#39;input_ids&#39;</span><span class="p">,</span>
<span class="n">input_ids_shape</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">&#39;hidden_states_input&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths_local</span><span class="p">,</span> <span class="s1">&#39;host_context_lengths&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="s1">&#39;position_ids&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="c1"># disable (or minimize) cross qkv computation at generation phase</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_qkv</span><span class="p">:</span>
<span class="c1"># disable</span>
<span class="n">encoder_output_shape</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cross_qkv_reuse</span><span class="p">,</span> <span class="s1">&#39;cross_qkv_reuse&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># minimize</span>
<span class="c1"># hacky way: such that qkv gemm becomes a gemv which is cheap and negligible</span>
<span class="n">encoder_output_shape</span> <span class="o">=</span> <span class="p">[</span>
<span class="mi">1</span><span class="p">,</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="p">]</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="k">else</span> <span class="p">[</span>
<span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># OOTB path doesn&#39;t have kv cache for now, so this encoder_output is</span>
<span class="c1"># a must-have input. We just use the encoder_output</span>
<span class="n">encoder_output_shape</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span>
<span class="c1"># in generation phase, cross kv cache is already filled during context phase, set to False</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
<span class="s1">&#39;cross_kv_cache_gen&#39;</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">encoder_output</span><span class="p">,</span> <span class="s1">&#39;encoder_output&#39;</span><span class="p">,</span>
<span class="n">encoder_output_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="s1">&#39;encoder_input_lengths&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">],</span>
<span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_mask</span><span class="p">,</span> <span class="s1">&#39;cross_attention_mask&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">kv_cache_block_pointers</span><span class="o">.</span><span class="n">shape</span>
<span class="n">shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="o">*</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">:]]</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">kv_cache_block_pointers</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;kv_cache_block_pointers&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_kv_cache_block_pointers</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;host_kv_cache_block_pointers&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
<span class="k">if</span> <span class="n">prompt_embedding_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="s1">&#39;prompt_embedding_table&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">gen_tasks</span> <span class="o">=</span> <span class="n">tasks</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">gen_tasks</span> <span class="o">=</span> <span class="n">tasks</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">gen_tasks</span><span class="p">,</span> <span class="s1">&#39;tasks&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="s1">&#39;prompt_vocab_size&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">next_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span>
<span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;1_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="n">next_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="n">next_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;1_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">key_value_cache</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_cache_buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="c1"># generation requests</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_context_fmha_for_generation</span><span class="p">:</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">host_past_key_value_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># previous [past_kv_length, is_context] has been deprecated. only past_kv_length should be given here</span>
<span class="c1"># Note we should use max_context_length here to align to max -- but isn&#39;t this done in attn plugin&#39;s max_element() already?</span>
<span class="n">host_past_key_value_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
<span class="p">[</span><span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_past_key_value_lengths</span><span class="p">,</span>
<span class="s1">&#39;host_past_key_value_lengths&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">&#39;host_request_types&#39;</span><span class="p">)</span>
<span class="c1"># Sequence lengths are not used in the context phase actually.</span>
<span class="n">sequence_length</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_context_fmha_for_generation</span><span class="p">:</span>
<span class="n">sequence_length</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="n">sequence_length</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="s1">&#39;sequence_length&#39;</span><span class="p">,</span>
<span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="p">))</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span><span class="p">,</span>
<span class="s1">&#39;host_sink_token_length&#39;</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">))</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;host_max_attention_window_sizes&#39;</span><span class="p">,</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths_local</span><span class="p">,</span> <span class="s1">&#39;host_context_lengths&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">attention_mask</span><span class="p">,</span> <span class="s1">&#39;attention_mask&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_custom_all_reduce</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span><span class="p">,</span> <span class="s1">&#39;all_reduce_workspace&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span><span class="p">:</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
<span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">+</span>
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="p">):</span>
<span class="n">lora_ranks</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_ranks_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_ranks</span><span class="p">],</span> <span class="n">lora_ranks</span><span class="p">)</span>
<span class="n">lora_module</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_module</span><span class="p">],</span> <span class="n">lora_module</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cpu&#39;</span><span class="p">),</span>
<span class="s1">&#39;host_encoder_input_lengths&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="c1"># Medusa mask and position offsets are fixed for the whole session.</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;medusa_packed_mask&#39;</span><span class="p">],</span> <span class="s1">&#39;medusa_packed_mask&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;medusa_position_offsets&#39;</span><span class="p">],</span>
<span class="s1">&#39;medusa_position_offsets&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tensors</span>
<span class="k">def</span> <span class="nf">_prepare_context_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="c1"># For Medusa, last_token_ids should contain the actual indices</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span> <span class="c1"># sub 1 from context_lengths for indices</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">if</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;max_context_length&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">([</span>
<span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
<span class="p">])</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">expand</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">}</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;input_ids&#39;</span><span class="p">)</span>
<span class="n">pad_id</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;pad_id&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">_prepare_attention_mask</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">pad_id</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">attention_mask</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">position_ids</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">attention_mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;attention_mask&#39;</span><span class="p">:</span> <span class="n">attention_mask</span><span class="p">,</span>
<span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span>
<span class="p">}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
<span class="n">ret</span><span class="p">[</span><span class="s1">&#39;position_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">position_ids</span>
<span class="k">return</span> <span class="n">ret</span>
<span class="k">def</span> <span class="nf">_prepare_generation_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">step</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;step&#39;</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">context_lengths</span> <span class="o">+</span> <span class="n">step</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="c1"># For Medusa, last_token_ids should be [bs * seq] and should contain the actual indices (starts from 1)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">expand</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="c1"># For Medusa, last_token_ids should be [bs, seq] and should contain the actual indices (starts from 0)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">expand</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">}</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">)</span>
<span class="n">num_beams</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;num_beams&#39;</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">attention_mask</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">,</span> <span class="mi">1</span><span class="p">))),</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">attention_mask</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">position_ids</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">attention_mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">,</span>
<span class="s1">&#39;attention_mask&#39;</span><span class="p">:</span> <span class="n">attention_mask</span><span class="p">,</span>
<span class="p">}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
<span class="n">ret</span><span class="p">[</span><span class="s1">&#39;position_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">position_ids</span>
<span class="k">return</span> <span class="n">ret</span>
<div class="viewcode-block" id="GenerationSession.pp_communicate_new_tokens">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.pp_communicate_new_tokens">[docs]</a>
<span class="k">def</span> <span class="nf">pp_communicate_new_tokens</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">should_stop</span><span class="p">,</span> <span class="n">cache_indir</span><span class="p">,</span>
<span class="n">sequence_length</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">for</span> <span class="n">pg</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">:</span>
<span class="k">if</span> <span class="n">pg</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">should_stop</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">cache_indir</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">should_stop</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">cache_indir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">return</span> <span class="n">should_stop</span></div>
<div class="viewcode-block" id="GenerationSession.pp_communicate_final_output_ids">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.pp_communicate_final_output_ids">[docs]</a>
<span class="k">def</span> <span class="nf">pp_communicate_final_output_ids</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">final_output_ids</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">return</span> <span class="n">final_output_ids</span></div>
<div class="viewcode-block" id="GenerationSession.finalize_decoder">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.finalize_decoder">[docs]</a>
<span class="k">def</span> <span class="nf">finalize_decoder</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">,</span>
<span class="n">in_progress</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="c1"># output shape of self.gather_tree: [batch_size, beam_width, output_len]</span>
<span class="n">beam_hyps_args</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_tgt</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_sequence_lengths_tgt</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span> <span class="ow">and</span> <span class="n">in_progress</span><span class="p">:</span>
<span class="c1"># self.gather_tree modifies these args.</span>
<span class="c1"># In streaming mode, this results in incorrect decoding in the following steps.</span>
<span class="n">beam_hyps_args</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">beam_hyps_args</span><span class="p">)</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_tree</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">,</span>
<span class="o">*</span><span class="n">beam_hyps_args</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">length_penalty</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">early_stopping</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">)</span>
<span class="c1"># Communicate ranks in Pipeline Parallelism</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pp_communicate_final_output_ids</span><span class="p">(</span>
<span class="n">final_output_ids</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="k">return</span> <span class="n">final_output_ids</span></div>
<div class="viewcode-block" id="GenerationSession.find_best_medusa_path">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.find_best_medusa_path">[docs]</a>
<span class="k">def</span> <span class="nf">find_best_medusa_path</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">next_logits</span><span class="p">,</span>
<span class="n">temp</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">best_path</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span>
<span class="n">best_path_len</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span>
<span class="n">next_tokens</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span>
<span class="n">zero_pad</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">input_ids</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">zero_pad</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="n">temp</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">new_tokens_raw</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span>
<span class="n">next_logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span>
<span class="p">)</span> <span class="c1"># TODO: can be done by treating [bs, nT, vocab] as [bs*nT, vocab] and using decoderOp?</span>
<span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">new_tokens_raw</span><span class="p">,</span> <span class="n">zero_pad</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">input_paths</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">input_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span><span class="p">]</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">new_paths</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">new_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span><span class="p">]</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">equality</span> <span class="o">=</span> <span class="n">input_paths</span><span class="p">[</span><span class="n">b</span><span class="p">][:,</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">==</span> <span class="n">new_paths</span><span class="p">[</span><span class="n">b</span><span class="p">][:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="n">paths_correct_len</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumprod</span><span class="p">(</span><span class="n">equality</span><span class="o">.</span><span class="n">int</span><span class="p">(),</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">best_path_len</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="n">paths_correct_len</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">best_path_len</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">best_path</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">paths_correct_len</span><span class="p">)</span>
<span class="n">next_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_paths</span><span class="p">[</span><span class="n">b</span><span class="p">][</span>
<span class="n">best_path</span><span class="p">[</span><span class="n">b</span><span class="p">]][:</span><span class="n">best_path_len</span><span class="p">[</span><span class="n">b</span><span class="p">]]</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">return</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_len</span><span class="p">,</span> <span class="n">next_tokens</span></div>
<div class="viewcode-block" id="GenerationSession.filter_medusa_logits">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.filter_medusa_logits">[docs]</a>
<span class="k">def</span> <span class="nf">filter_medusa_logits</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">,</span>
<span class="n">medusa_logits</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> medusa_logits is of shape [nMH, bs, nMT+1, vocab]</span>
<span class="sd"> Returns [nMH, bs, vocab]</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">filtered_logits</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">medusa_logits</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">medusa_logits</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">medusa_logits</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">idx</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span><span class="p">[</span><span class="n">best_path</span><span class="p">[</span><span class="n">b</span><span class="p">],</span> <span class="n">best_path_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
<span class="n">filtered_logits</span><span class="p">[:,</span> <span class="n">b</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="p">[:,</span> <span class="n">b</span><span class="p">,</span> <span class="n">idx</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span>
<span class="k">return</span> <span class="n">filtered_logits</span></div>
<div class="viewcode-block" id="GenerationSession.get_next_medusa_tokens">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.get_next_medusa_tokens">[docs]</a>
<span class="k">def</span> <span class="nf">get_next_medusa_tokens</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">next_medusa_logits</span><span class="p">):</span>
<span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">next_medusa_logits</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="p">]</span> <span class="c1"># dummy token for now, TODO: update tree_ids and remove this</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">):</span>
<span class="n">medusa_token</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">next_medusa_logits</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:],</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_topks</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">indices</span>
<span class="n">next_medusa_tokens</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">medusa_token</span><span class="p">)</span>
<span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">next_medusa_tokens</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">next_medusa_tokens</span></div>
<div class="viewcode-block" id="GenerationSession.update_kv_cache_draft_token_location">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.update_kv_cache_draft_token_location">[docs]</a>
<span class="k">def</span> <span class="nf">update_kv_cache_draft_token_location</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span>
<span class="n">best_path_len</span><span class="p">):</span>
<span class="n">best_path_len_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">best_path_len</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">accepted_draft_token_counts</span> <span class="o">=</span> <span class="n">best_path_len_tensor</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">accepted_draft_token_offsets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">accepted_draft_token_offsets</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span>
<span class="n">accepted_draft_token_counts</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">accepted_draft_token_offsets_cpu</span> <span class="o">=</span> <span class="n">accepted_draft_token_offsets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
<span class="n">packed_accepted_draft_tokens_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">accepted_draft_token_offsets_cpu</span><span class="p">[</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">seq_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">seq_start</span> <span class="o">=</span> <span class="n">accepted_draft_token_offsets_cpu</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">]</span>
<span class="n">seq_end</span> <span class="o">=</span> <span class="n">accepted_draft_token_offsets_cpu</span><span class="p">[</span><span class="n">seq_idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
<span class="n">seq_accepted_draft_count</span> <span class="o">=</span> <span class="n">seq_end</span> <span class="o">-</span> <span class="n">seq_start</span>
<span class="n">best_path_idx</span> <span class="o">=</span> <span class="n">best_path</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">]</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
<span class="n">best_path</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="n">best_path</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">]</span>
<span class="n">seq_accepted_token_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span><span class="p">[</span>
<span class="n">best_path_idx</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span><span class="mi">1</span> <span class="o">+</span> <span class="n">seq_accepted_draft_count</span><span class="p">]</span>
<span class="n">packed_accepted_draft_tokens_indices</span><span class="p">[</span>
<span class="n">seq_start</span><span class="p">:</span><span class="n">seq_end</span><span class="p">]</span> <span class="o">=</span> <span class="n">seq_accepted_token_indices</span> <span class="o">-</span> <span class="mi">1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">accepted_draft_token_offsets</span><span class="p">,</span>
<span class="n">packed_accepted_draft_tokens_indices</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span></div>
<div class="viewcode-block" id="GenerationSession.update_output_ids_by_offset">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.update_output_ids_by_offset">[docs]</a>
<span class="k">def</span> <span class="nf">update_output_ids_by_offset</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_generated_ids</span><span class="p">,</span> <span class="n">offsets</span><span class="p">):</span>
<span class="c1"># output_ids [batch_size, padded_input_length]</span>
<span class="c1"># new_generated_ids [batch_size, padded_accepted_length]</span>
<span class="c1"># offsets [batch_size]</span>
<span class="c1"># FIXME: using fused kernel to update the padded output ids.</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="n">offsets</span><span class="p">[</span><span class="n">b</span><span class="p">]:(</span>
<span class="n">offsets</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
<span class="p">)]</span> <span class="o">=</span> <span class="n">new_generated_ids</span><span class="p">[</span><span class="n">b</span><span class="p">][:</span><span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]]</span></div>
<div class="viewcode-block" id="GenerationSession.next_medusa_input_ids">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.next_medusa_input_ids">[docs]</a>
<span class="k">def</span> <span class="nf">next_medusa_input_ids</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="c1"># self.new_tokens [batch_size, padded_accepted_length]</span>
<span class="c1"># self.accept_lengths [batch_size]</span>
<span class="c1"># self.medusa_new_tokens [batch_size, num_medusa_tokens]</span>
<span class="c1"># FIXME: using fused kernel to generate the new medusa input ids.</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">[</span>
<span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="p">:]</span></div>
<span class="c1"># OPTIMIZE: need to optimize this early-stop workflow.</span>
<div class="viewcode-block" id="GenerationSession.early_stop_criteria">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.early_stop_criteria">[docs]</a>
<span class="k">def</span> <span class="nf">early_stop_criteria</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">should_stop</span><span class="p">):</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_step</span><span class="p">[</span><span class="n">b</span><span class="p">]:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">continue</span>
<span class="c1"># output sequence length criteria.</span>
<span class="n">prev_total_output_length</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
<span class="c1"># end id criteria.</span>
<span class="n">should_stop_with_end_id</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">any</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">[</span><span class="n">b</span><span class="p">])</span>
<span class="n">end_id_pos</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]]</span> <span class="o">==</span>
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">[</span><span class="n">b</span><span class="p">])</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_step</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_step</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="ow">or</span> <span class="p">(</span>
<span class="n">prev_total_output_length</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">&gt;=</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">)</span> <span class="ow">or</span> <span class="n">should_stop_with_end_id</span>
<span class="c1"># update accept lengths for the current step.</span>
<span class="k">if</span> <span class="p">(</span><span class="n">prev_total_output_length</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">&gt;=</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="n">prev_total_output_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">])</span>
<span class="k">if</span> <span class="n">should_stop_with_end_id</span><span class="p">:</span>
<span class="c1"># get the position of first end_id.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">end_id_pos</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">total_accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
<span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">or</span> <span class="p">(</span><span class="n">step</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span>
<span class="mi">1</span><span class="p">)</span> <span class="ow">or</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_step</span><span class="p">)</span>
<span class="k">return</span> <span class="n">should_stop</span></div>
<div class="viewcode-block" id="GenerationSession.process_logits_for_medusa_mode">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.process_logits_for_medusa_mode">[docs]</a>
<span class="k">def</span> <span class="nf">process_logits_for_medusa_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span>
<span class="n">logits</span><span class="p">,</span> <span class="n">context_has_medusa_tokens</span><span class="p">,</span>
<span class="n">next_step_buffer</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">):</span>
<span class="n">medusa_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;medusa_logits&#39;</span><span class="p">]</span>
<span class="n">best_path</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">best_path_lengths</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="kc">False</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="c1"># logits buffer is of shape [bs, medusa_tokens+1, vocab]</span>
<span class="c1"># but during context phase, we get only [bs, 1, vocab] but contiguous</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)[:</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">logits</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">next_main_token_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
<span class="n">next_main_token</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">next_main_token_logits</span><span class="p">,</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span>
<span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">next_main_token</span>
<span class="c1"># NOTE: stop criteria.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_step</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">)</span>
<span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">equal</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">):</span>
<span class="c1"># stop if context phase output EOS</span>
<span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="c1"># NOTE: only one token&#39;s medusa logit will be written in.</span>
<span class="n">medusa_logits</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span>
<span class="n">next_medusa_logits</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
<span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_next_medusa_tokens</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">next_medusa_logits</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span> <span class="o">=</span> <span class="n">next_medusa_tokens</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_tree_ids</span><span class="p">[</span>
<span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span><span class="p">:]]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">total_accept_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">next_token_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
<span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">,</span> <span class="n">next_main_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">find_best_medusa_path</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span>
<span class="n">next_token_logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">best_path_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">to_padded_tensor</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">nested_tensor</span><span class="p">(</span><span class="n">next_main_tokens</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="c1">#FIXME end id padding.</span>
<span class="n">next_medusa_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_medusa_logits</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">,</span> <span class="n">medusa_logits</span><span class="p">)</span>
<span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_next_medusa_tokens</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">next_medusa_logits</span><span class="p">)</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">early_stop_criteria</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span>
<span class="n">should_stop</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span> <span class="o">=</span> <span class="n">next_medusa_tokens</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_tree_ids</span><span class="p">[</span>
<span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span><span class="p">:]]</span>
<span class="c1"># NOTE: self.accept_lengths are the lengths of accepted tokens in the current step</span>
<span class="c1"># NOTE: self.sequence_length_buffer = num_past_kv_cache (accepted) + num_medusa_tokens + 1</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_output_ids_by_offset</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># Noteself.sequence_length_buffer = num_past_kv_cache (accepted) + num_medusa_tokens</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_output_ids_by_offset</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">next_medusa_input_ids</span><span class="p">()</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">best_path</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">best_path_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_kv_cache_draft_token_location</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span>
<span class="c1"># NOTE: set the accepted tokens for the last step.</span>
<span class="k">if</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="c1"># remove num_medusa_tokens for next generation.</span>
<span class="c1"># Runtime: denotes kv cache length start positions.</span>
<span class="c1"># Output: denotes the length of sequence length (input ids + output ids)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span>
<span class="n">next_step_buffer</span><span class="p">[</span><span class="s1">&#39;host_past_key_value_lengths&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">to_torch</span><span class="p">()</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">)</span>
<span class="k">return</span> <span class="n">should_stop</span></div>
<div class="viewcode-block" id="GenerationSession.handle_per_step">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.handle_per_step">[docs]</a>
<span class="k">def</span> <span class="nf">handle_per_step</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span> <span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">host_kv_cache_block_pointers</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">next_step_tensors</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">RuntimeTensor</span><span class="p">],</span> <span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="p">,</span> <span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_0</span>
<span class="n">this_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">this_tgt_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">next_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_1</span>
<span class="n">this_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">this_tgt_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">next_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">model_inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_context_inputs</span><span class="p">(</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="o">=</span><span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_context_length</span><span class="p">,</span>
<span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">pad_id</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span><span class="p">,</span>
<span class="n">eos_id</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;position_ids&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">host_kv_cache_block_pointers</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">get_block_pointers</span><span class="p">(</span>
<span class="mi">1</span><span class="p">)</span>
<span class="n">kv_cache_block_pointers</span> <span class="o">=</span> <span class="n">host_kv_cache_block_pointers</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">ctx_tensors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_context_shape_buffer</span><span class="p">(</span>
<span class="n">input_ids</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">position_ids</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">,</span>
<span class="n">this_src_cache_indirection</span><span class="p">,</span> <span class="n">kv_cache_block_pointers</span><span class="p">,</span>
<span class="n">host_kv_cache_block_pointers</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">prompt_vocab_size</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">)</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">ctx_context</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_set_tensors</span><span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">ctx_tensors</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span> <span class="o">=</span> <span class="p">{</span>
<span class="n">name</span><span class="p">:</span> <span class="n">tensor</span><span class="o">.</span><span class="n">to_torch</span><span class="p">()</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="n">ctx_tensors</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="p">}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span><span class="p">:</span>
<span class="c1"># context mode, clean cuda graph instances</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">)]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_check_tensors</span><span class="p">(</span><span class="n">context</span><span class="p">)</span>
<span class="c1"># dynamic_decoder currently use torch&#39;s current stream, so must let TRT enqueue use same stream here</span>
<span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">cuda_stream</span>
<span class="n">instance_idx</span> <span class="o">=</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span>
<span class="n">instance_idx</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># launch cuda graph</span>
<span class="n">CUASSERT</span><span class="p">(</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphLaunch</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">],</span> <span class="n">stream</span><span class="p">))</span>
<span class="n">ok</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">ok</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_run</span><span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">ok</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Executing TRT engine failed step=</span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">!&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">context_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span>
<span class="n">context_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="c1"># gather last token of context</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="c1"># reshape self.buffer[&#39;logits&#39;] from [bs, max_context_length, vocab]</span>
<span class="c1"># to [1, bs * max_context_length, vocab]</span>
<span class="c1"># Note that the data are put in the buffer without padding although</span>
<span class="c1"># the allocated buffer has padding.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">index_select</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">index</span><span class="o">=</span><span class="n">last_token_ids</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">))</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">beam_width</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span>
<span class="c1"># these tiled tensors are returned by handle_per_step(), so they can relay to the next generation calls</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">attention_mask</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="n">context_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="n">host_context_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">)</span>
<span class="k">if</span> <span class="n">encoder_input_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
<span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="k">if</span> <span class="n">tasks</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">tasks</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">tasks</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="c1"># Move tiling before logit computing of context</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="c1"># Note: this tiles both self attn cache and cross attn cache!</span>
<span class="c1"># both names contain &quot;present_key_value&quot;</span>
<span class="k">if</span> <span class="s2">&quot;present_key_value&quot;</span> <span class="ow">in</span> <span class="n">key</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="n">generation_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span><span class="p">:</span>
<span class="n">generation_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="c1"># Initialize sequence_lengths (no paddings) for the generation phase.</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="c1"># NOTE: handle next step.</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">step</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># Set shape and address for the next step</span>
<span class="n">model_inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_generation_inputs</span><span class="p">(</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">,</span>
<span class="n">step</span><span class="o">=</span><span class="n">step</span><span class="p">,</span>
<span class="n">num_beams</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;position_ids&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="c1"># Prepare for the next step, and always allocate 1 token slot.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="c1"># Iterate to the next step in KV cache manager.</span>
<span class="c1"># Increase number of tokens for all unfinished sequences.</span>
<span class="c1"># And allocate new blocks if needed.</span>
<span class="c1"># We set this to False for all sequences, since we use only length criterion to stop now</span>
<span class="c1"># OPTIMIZE: find a better of adding multiple tokens for paged kv cache.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="c1"># Allocate kv cache token slots for next step.</span>
<span class="c1"># Make sure there are always &gt; (num_medusa_tokens + 1) free token slots.</span>
<span class="c1"># Allocate (num_medusa_tokens + 1) * 2 for safety as we don&#39;t know the current step or next step&#39;s accepted lengths.</span>
<span class="n">add_token_count</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span>
<span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="k">assert</span> <span class="n">add_token_count</span> <span class="o">&gt;</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">new_tokens</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">add_token_count</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">False</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">False</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="n">host_kv_cache_block_pointers</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">get_block_pointers</span><span class="p">(</span>
<span class="n">beam_width</span><span class="p">)</span>
<span class="n">kv_cache_block_pointers</span> <span class="o">=</span> <span class="n">host_kv_cache_block_pointers</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">next_context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_1</span> <span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_0</span>
<span class="n">next_step_tensors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_next_step_shape_buffer</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">position_ids</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">,</span>
<span class="n">next_src_cache_indirection</span><span class="p">,</span> <span class="n">kv_cache_block_pointers</span><span class="p">,</span>
<span class="n">host_kv_cache_block_pointers</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">prompt_vocab_size</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">)</span>
<span class="c1"># there are some tensors created inside the _get_next_step_shape_buffer, not owned by any object</span>
<span class="c1"># needs to pro-long the life time of the tensors inside the next_step_tensors array</span>
<span class="c1"># otherwise, it maybe released before the next step actually enqueued</span>
<span class="c1"># one way to prolong it is to return the list, and destroy it in next step by assigning new values</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_set_tensors</span><span class="p">(</span><span class="n">next_context</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_capture_cuda_graph_and_instantiate</span><span class="p">(</span>
<span class="n">next_context</span><span class="p">,</span> <span class="n">stream</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">logits</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">process_logits_for_medusa_mode</span><span class="p">(</span>
<span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">logits_processor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">,</span>
<span class="n">in_progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># keep the shape as same as huggingface stopping_criteria</span>
<span class="n">final_output_ids_</span> <span class="o">=</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">logits_processor</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">final_output_ids_</span><span class="p">,</span>
<span class="n">logits</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">logits</span>
<span class="c1"># [batch_size x beam_width, vocab_size_padded] -&gt; [batch_size, beam_width, vocab_size_padded]</span>
<span class="n">next_token_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
<span class="n">decode_step</span> <span class="o">=</span> <span class="n">step</span> <span class="o">+</span> <span class="n">max_context_length</span>
<span class="n">stop_words_list_ptrs</span><span class="p">,</span> <span class="n">stop_words_lens</span><span class="p">,</span> <span class="n">max_stop_words_len</span> <span class="o">=</span> <span class="n">stop_words_data</span>
<span class="n">bad_words_list_ptrs</span><span class="p">,</span> <span class="n">bad_words_lens</span><span class="p">,</span> <span class="n">max_bad_words_len</span> <span class="o">=</span> <span class="n">bad_words_data</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span>
<span class="n">next_token_logits</span><span class="p">,</span> <span class="n">decode_step</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">,</span>
<span class="n">ite</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding_bias_opt</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">sequence_limit_lengths</span><span class="p">,</span>
<span class="n">stop_words_list_ptrs</span><span class="p">,</span> <span class="n">stop_words_lens</span><span class="p">,</span>
<span class="n">max_stop_words_len</span><span class="p">,</span> <span class="n">bad_words_list_ptrs</span><span class="p">,</span> <span class="n">bad_words_lens</span><span class="p">,</span>
<span class="n">max_bad_words_len</span><span class="p">,</span> <span class="n">no_repeat_ngram_size</span><span class="p">,</span>
<span class="n">this_src_cache_indirection</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_probs_tiled</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span><span class="p">,</span>
<span class="n">this_tgt_cache_indirection</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_tgt</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_sequence_lengths_tgt</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">)</span>
<span class="k">if</span> <span class="n">stopping_criteria</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">,</span>
<span class="n">in_progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># keep the shape as same as huggingface stopping_criteria</span>
<span class="n">final_output_ids_</span> <span class="o">=</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
<span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">stopping_criteria</span><span class="p">(</span>
<span class="n">step</span><span class="p">,</span> <span class="n">final_output_ids_</span><span class="p">,</span> <span class="n">logits</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pp_communicate_new_tokens</span><span class="p">(</span>
<span class="n">should_stop</span><span class="p">,</span> <span class="n">this_tgt_cache_indirection</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">if</span> <span class="p">(</span><span class="n">step</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">()):</span>
<span class="c1"># Free all blocks in all sequences.</span>
<span class="c1"># With in-flight batching and while loop we&#39;ll free some sequences, when they are done</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">True</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dump_debug_buffers</span><span class="p">(</span><span class="n">step</span><span class="p">)</span>
<span class="k">if</span> <span class="n">next_step_tensors</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span> <span class="o">=</span> <span class="p">{</span>
<span class="n">name</span><span class="p">:</span> <span class="n">tensor</span><span class="o">.</span><span class="n">to_torch</span><span class="p">()</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="n">next_step_tensors</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="p">}</span>
<span class="k">return</span> <span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">context_logits</span><span class="p">,</span> <span class="n">generation_logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span></div>
<div class="viewcode-block" id="GenerationSession.dump_debug_buffers">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.dump_debug_buffers">[docs]</a>
<span class="k">def</span> <span class="nf">dump_debug_buffers</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># restricted written tensors according to filter</span>
<span class="n">debug_tensor_names</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">debug_tensor_names</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">all</span><span class="p">([</span><span class="n">kk</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">k</span> <span class="k">for</span> <span class="n">kk</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span><span class="p">]):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="n">k</span><span class="p">)</span>
<span class="n">debug_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;tllm_debug/PP_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_rank</span><span class="si">}</span><span class="s2">/TP_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_rank</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">debug_dir</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="c1"># convert tensor name to valid file name</span>
<span class="n">fname</span> <span class="o">=</span> <span class="n">name</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;/&quot;</span><span class="p">,</span> <span class="s2">&quot;.&quot;</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">torch_to_numpy</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="n">np</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">debug_dir</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">fname</span><span class="si">}</span><span class="s2">-step</span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">.npy&quot;</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="n">txt_format</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="si">%d</span><span class="s2">&quot;</span> <span class="k">if</span> <span class="n">t</span><span class="o">.</span><span class="n">dtype</span> <span class="ow">in</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">]</span> <span class="k">else</span> <span class="s1">&#39;</span><span class="si">%.18e</span><span class="s1">&#39;</span>
<span class="n">np</span><span class="o">.</span><span class="n">savetxt</span><span class="p">(</span>
<span class="n">debug_dir</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">fname</span><span class="si">}</span><span class="s2">-step</span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">.txt&quot;</span><span class="p">,</span>
<span class="n">t</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="c1"># savetxt accepts 2 dims only</span>
<span class="n">fmt</span><span class="o">=</span><span class="n">txt_format</span><span class="p">)</span></div>
<div class="viewcode-block" id="GenerationSession.decode_regular">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_regular">[docs]</a>
<span class="k">def</span> <span class="nf">decode_regular</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="p">,</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">kv_cache_block_pointers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">host_kv_cache_block_pointers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">outputs_generation_logits</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">def</span> <span class="nf">get_outputs_dict</span><span class="p">(</span><span class="n">output_ids</span><span class="p">):</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;output_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">output_ids</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_log_probs</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;log_probs&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_cum_log_probs</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;cum_log_probs&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span>
<span class="k">if</span> <span class="n">output_sequence_lengths</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span>
<span class="s1">&#39;sequence_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">])</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;context_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;generation_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_generation_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;medusa_output_tokens&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;accept_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_temperature</span> <span class="o">!=</span> <span class="mf">0.0</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;medusa_output_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="n">benchmark_profiler</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;benchmark_profiler&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">generation_phase_step_count</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">def</span> <span class="nf">profile_fn</span><span class="p">(</span><span class="n">benchmark_profiler_obj</span><span class="p">,</span> <span class="n">step_count</span><span class="p">):</span>
<span class="k">if</span> <span class="n">benchmark_profiler_obj</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">benchmark_profiler_obj</span><span class="o">.</span><span class="n">record_cuda_event</span><span class="p">(</span><span class="s1">&#39;last_token&#39;</span><span class="p">)</span>
<span class="n">benchmark_profiler_obj</span><span class="o">.</span><span class="n">record_elapsed_time</span><span class="p">(</span>
<span class="s1">&#39;first_token&#39;</span><span class="p">,</span> <span class="s1">&#39;last_token&#39;</span><span class="p">,</span> <span class="s1">&#39;generation_time&#39;</span><span class="p">)</span>
<span class="n">benchmark_profiler_obj</span><span class="o">.</span><span class="n">add_aux_info</span><span class="p">(</span><span class="s1">&#39;generation_step_count&#39;</span><span class="p">,</span>
<span class="n">step_count</span><span class="p">)</span>
<span class="n">next_step_tensors</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
<span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">context_logits</span><span class="p">,</span> <span class="n">generation_logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">handle_per_step</span><span class="p">(</span>
<span class="n">cache_indirections</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">,</span> <span class="n">host_kv_cache_block_pointers</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">ite</span><span class="p">,</span> <span class="n">sequence_limit_lengths</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="p">,</span> <span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">encoder_output</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="n">stopping_criteria</span><span class="p">,</span> <span class="n">logits_processor</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">if</span> <span class="n">benchmark_profiler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">benchmark_profiler</span><span class="o">.</span><span class="n">record_cuda_event</span><span class="p">(</span><span class="s1">&#39;first_token&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">generation_phase_step_count</span> <span class="o">=</span> <span class="n">generation_phase_step_count</span> <span class="o">+</span> <span class="mi">1</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="n">context_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span><span class="p">:</span>
<span class="n">outputs_generation_logits</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">generation_logits</span><span class="p">)</span>
<span class="k">if</span> <span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="n">profile_fn</span><span class="p">(</span><span class="n">benchmark_profiler</span><span class="p">,</span> <span class="n">generation_phase_step_count</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="c1"># just hack away for now</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">return</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">final_output_ids</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;context_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;generation_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_generation_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">,</span> <span class="s2">&quot;the custom decoder doesn&#39;t support medusa.&quot;</span>
<span class="n">profile_fn</span><span class="p">(</span><span class="n">benchmark_profiler</span><span class="p">,</span> <span class="n">generation_phase_step_count</span><span class="p">)</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">return</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">final_output_ids</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;context_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;generation_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_generation_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">None</span></div>
<div class="viewcode-block" id="GenerationSession.decode_stream">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_stream">[docs]</a>
<span class="k">def</span> <span class="nf">decode_stream</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="p">,</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">kv_cache_block_pointers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">host_kv_cache_block_pointers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">def</span> <span class="nf">get_outputs_dict</span><span class="p">(</span><span class="n">output_ids</span><span class="p">):</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;output_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">output_ids</span>
<span class="k">if</span> <span class="n">output_sequence_lengths</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span>
<span class="s1">&#39;sequence_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">])</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;context_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="n">next_step_tensors</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
<span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">context_logits</span><span class="p">,</span> <span class="n">generation_logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">handle_per_step</span><span class="p">(</span>
<span class="n">cache_indirections</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">,</span> <span class="n">host_kv_cache_block_pointers</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">ite</span><span class="p">,</span> <span class="n">sequence_limit_lengths</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="p">,</span> <span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">encoder_output</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="n">stopping_criteria</span><span class="p">,</span> <span class="n">logits_processor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="n">context_logits</span>
<span class="k">if</span> <span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">,</span>
<span class="n">in_progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">final_output_ids</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="k">return</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">final_output_ids</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="kc">None</span></div>
<div class="viewcode-block" id="GenerationSession.decode_batch">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_batch">[docs]</a>
<span class="k">def</span> <span class="nf">decode_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">input_ids</span><span class="p">,</span> <span class="n">context_lengths</span> <span class="o">=</span> <span class="n">_prepare_input_ids</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">sampling_config</span><span class="p">,</span>
<span class="n">streaming</span><span class="o">=</span><span class="n">streaming</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<span class="c1"># As dynamic_decoder uses torch&#39;s current stream, we must ensure it runs on the same stream that</span>
<span class="c1"># dynamic_decoder was set up with</span>
<div class="viewcode-block" id="GenerationSession.decode">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode">[docs]</a>
<span class="nd">@cuda_stream_guard</span>
<span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stop_words_list</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">bad_words_list</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">scfg</span> <span class="o">=</span> <span class="n">sampling_config</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">beam_width</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="n">host_context_lengths</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="k">assert</span> <span class="n">batch_size</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> \
<span class="s2">&quot;Given batch size is different from the one used in setup(),&quot;</span> \
<span class="s2">&quot;rerun the setup function with the new batch size to avoid buffer overflow.&quot;</span>
<span class="k">assert</span> <span class="n">max_context_length</span> <span class="o">&lt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_context_length</span><span class="p">,</span> \
<span class="s2">&quot;Given input length is large then the one used in setup(),&quot;</span> \
<span class="s2">&quot;rerun the setup function with the new max_context_length to avoid buffer overflow.&quot;</span>
<span class="k">assert</span> <span class="n">beam_width</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_width</span><span class="p">,</span> \
<span class="s2">&quot;Given beam width is different from the one used in setup(),&quot;</span> \
<span class="s2">&quot;rerun the setup function with the new beam width to avoid buffer overflow.&quot;</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">&lt;=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> \
<span class="s2">&quot;Given sink token length is larger than shortest context length,&quot;</span> \
<span class="s2">&quot;rerun the setup function with a smaller sink token length.&quot;</span>
<span class="n">ite</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># index of local batches, will always be 0 if pp_size = 1</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="ow">and</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;Packed 2D input must have shape [1, &lt;sum of input lengths&gt;]&quot;</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">__setup_decoder</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;Buffer not allocated, please call setup first!&#39;</span><span class="p">)</span>
<span class="n">sequence_limit_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># Sequence_lengths for the dynamic decoder still has the input paddings.</span>
<span class="n">sequence_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">max_context_length</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">cache_indirections</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="p">),</span>
<span class="mi">0</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="p">),</span>
<span class="mi">0</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="p">]</span> <span class="c1"># ping-pong buffers</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">max_num_tokens</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">max_num_tokens</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">))</span>
<span class="c1"># Init KV cache block manager</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">bubble_len</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">bubble_len</span> <span class="o">+=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span> <span class="o">-</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">)</span>
<span class="n">max_blocks_per_seq</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">+</span> <span class="n">bubble_len</span><span class="p">)</span> <span class="o">/</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_one_more_block</span><span class="p">:</span>
<span class="n">max_blocks_per_seq</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="n">blocks</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span> <span class="o">*</span> <span class="n">max_blocks_per_seq</span>
<span class="n">memory_pools</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span> <span class="o">=</span> <span class="n">KVCacheManager</span><span class="p">(</span>
<span class="n">memory_pools</span><span class="p">,</span> <span class="n">blocks</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span> <span class="n">max_blocks_per_seq</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_one_more_block</span><span class="p">)</span>
<span class="c1"># Add sequences to the manager</span>
<span class="k">for</span> <span class="n">bi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">generation_sequence</span> <span class="o">=</span> <span class="n">GenerationSequence</span><span class="p">(</span><span class="n">seq_idx</span><span class="o">=</span><span class="n">bi</span><span class="p">,</span>
<span class="n">batch_idx</span><span class="o">=</span><span class="n">bi</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">add_sequence</span><span class="p">(</span><span class="n">generation_sequence</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_kv_cache_quant</span><span class="p">():</span>
<span class="c1"># Since torch does not support fp8 now, using int8 here.</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">int8</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">history_max_seq_length</span> <span class="o">=</span> <span class="p">[</span><span class="n">max_context_length</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span> <span class="o">=</span> <span class="n">KVCacheUpdater</span><span class="p">()</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span><span class="o">.</span><span class="n">init_paged_kv_cache</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span> <span class="n">kv_cache_type</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">past_key_value_list</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span><span class="o">.</span><span class="n">init_linear_kv_cache</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span> <span class="n">kv_cache_type</span><span class="p">,</span>
<span class="n">past_key_value_list</span><span class="p">)</span>
<span class="n">stop_words_lens</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">stop_words_list_ptrs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">max_stop_words_len</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">stop_words_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">max_stop_words_len</span> <span class="o">=</span> <span class="n">stop_words_list</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="n">stop_words_lens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
<span class="n">max_stop_words_len</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">stop_words_list_ptrs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">for</span> <span class="n">bi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">stop_words_list_ptrs</span><span class="p">[</span><span class="n">bi</span><span class="p">]</span> <span class="o">=</span> <span class="n">stop_words_list</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(</span>
<span class="p">)</span> <span class="o">+</span> <span class="n">bi</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">max_stop_words_len</span>
<span class="n">stop_words_list_ptrs</span> <span class="o">=</span> <span class="n">stop_words_list_ptrs</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">stop_words_data</span> <span class="o">=</span> <span class="p">(</span><span class="n">stop_words_list_ptrs</span><span class="p">,</span> <span class="n">stop_words_lens</span><span class="p">,</span>
<span class="n">max_stop_words_len</span><span class="p">)</span>
<span class="n">bad_words_lens</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">bad_words_list_ptrs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">max_bad_words_len</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">bad_words_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">max_bad_words_len</span> <span class="o">=</span> <span class="n">bad_words_list</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="n">bad_words_lens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
<span class="n">max_bad_words_len</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">bad_words_list_ptrs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">for</span> <span class="n">bi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">bad_words_list_ptrs</span><span class="p">[</span><span class="n">bi</span><span class="p">]</span> <span class="o">=</span> <span class="n">bad_words_list</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(</span>
<span class="p">)</span> <span class="o">+</span> <span class="n">bi</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">max_bad_words_len</span>
<span class="n">bad_words_list_ptrs</span> <span class="o">=</span> <span class="n">bad_words_list_ptrs</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">bad_words_data</span> <span class="o">=</span> <span class="p">(</span><span class="n">bad_words_list_ptrs</span><span class="p">,</span> <span class="n">bad_words_lens</span><span class="p">,</span>
<span class="n">max_bad_words_len</span><span class="p">)</span>
<span class="c1"># start context phase</span>
<span class="k">if</span> <span class="n">streaming</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode_stream</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span> <span class="n">sequence_lengths</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">ite</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">,</span> <span class="n">stop_words_data</span><span class="p">,</span> <span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">output_sequence_lengths</span><span class="p">,</span> <span class="n">return_dict</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="n">stopping_criteria</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode_regular</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span> <span class="n">sequence_lengths</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">ite</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">,</span> <span class="n">stop_words_data</span><span class="p">,</span> <span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">output_sequence_lengths</span><span class="p">,</span> <span class="n">return_dict</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="n">stopping_criteria</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
</div>
<div class="viewcode-block" id="ChatGLMGenerationSession">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.ChatGLMGenerationSession">[docs]</a>
<span class="k">class</span> <span class="nc">ChatGLMGenerationSession</span><span class="p">(</span><span class="n">GenerationSession</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
<span class="n">model_config</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="p">,</span>
<span class="n">stream</span><span class="p">,</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">def</span> <span class="nf">_prepare_context_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;max_context_length&#39;</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">input_lengths_acc</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">([</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">cuda</span><span class="p">(),</span> <span class="n">context_lengths</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span><span class="p">]:</span><span class="n">input_lengths_acc</span><span class="p">[</span>
<span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span>
<span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">position_ids</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">)</span>
<span class="c1"># specialization for GLM series models</span>
<span class="k">if</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;pad_id&quot;</span><span class="p">]</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">50256</span><span class="p">,</span> <span class="mi">50259</span><span class="p">]:</span>
<span class="k">if</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;pad_id&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="mi">50256</span><span class="p">:</span> <span class="c1"># glm_2b / glm_10b</span>
<span class="n">mask_ids</span> <span class="o">=</span> <span class="p">[</span><span class="mi">50260</span><span class="p">,</span> <span class="mi">50264</span><span class="p">,</span> <span class="mi">50263</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span> <span class="c1"># glm_10b_chinese / glm_large_chinese</span>
<span class="n">mask_ids</span> <span class="o">=</span> <span class="p">[</span><span class="mi">50003</span><span class="p">,</span> <span class="mi">50008</span><span class="p">,</span> <span class="mi">50009</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="o">=</span> \
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">length</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;input_ids&quot;</span><span class="p">][</span><span class="n">i</span><span class="p">]</span>
<span class="n">mask_index</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">input_ids</span> <span class="o">==</span> <span class="nb">id</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">int</span><span class="p">()</span> <span class="k">for</span> <span class="nb">id</span> <span class="ow">in</span> <span class="n">mask_ids</span>
<span class="p">]</span>
<span class="n">tail_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="n">max_context_length</span><span class="p">])</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">mask_index</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">tail_index</span><span class="p">)</span>
<span class="n">mask_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">mask_index</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">min</span><span class="p">()</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">mask_index</span><span class="p">)</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">mask_index</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">length</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">2</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;position_ids&#39;</span><span class="p">:</span> <span class="n">position_ids</span><span class="p">,</span>
<span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span>
<span class="p">}</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">inputs</span><span class="p">[</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">attention_mask</span>
<span class="k">return</span> <span class="n">inputs</span>
<span class="k">def</span> <span class="nf">_prepare_generation_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">step</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;step&#39;</span><span class="p">)</span>
<span class="n">num_beams</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;num_beams&#39;</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="k">def</span> <span class="nf">_tile_beam_width_chatglm</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">new_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">num_beams</span>
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="n">tile_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">tile_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
<span class="k">return</span> <span class="n">new_tensor</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">step</span> <span class="o">+</span> <span class="mi">2</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width_chatglm</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">data</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="c1"># specialization for GLM series models</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">([[</span><span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]],</span> <span class="p">[</span><span class="n">step</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]])</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">([[</span><span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span><span class="p">],</span>
<span class="p">[</span><span class="n">step</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]])</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;position_ids&#39;</span><span class="p">:</span> <span class="n">position_ids</span><span class="p">,</span>
<span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span>
<span class="p">}</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">inputs</span><span class="p">[</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">attention_mask</span>
<span class="k">return</span> <span class="n">inputs</span></div>
<div class="viewcode-block" id="QWenForCausalLMGenerationSession">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.QWenForCausalLMGenerationSession">[docs]</a>
<span class="k">class</span> <span class="nc">QWenForCausalLMGenerationSession</span><span class="p">(</span><span class="n">GenerationSession</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">global_max_input_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span><span class="p">,</span>
<span class="n">global_max_output_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4096</span><span class="p">,</span>
<span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="n">debug_tensors_to_save</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="n">cuda_graph_mode</span><span class="p">,</span>
<span class="n">stream</span><span class="o">=</span><span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">global_max_input_length</span> <span class="o">=</span> <span class="n">global_max_input_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">global_max_output_length</span> <span class="o">=</span> <span class="n">global_max_output_length</span>
<div class="viewcode-block" id="QWenForCausalLMGenerationSession.generate">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.QWenForCausalLMGenerationSession.generate">[docs]</a>
<span class="k">def</span> <span class="nf">generate</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">runtime_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">max_input_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">input_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="n">max_new_tokens</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">max_new_tokens</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">global_max_output_length</span> <span class="o">-</span> <span class="n">max_input_length</span><span class="p">)</span>
<span class="c1"># setup batch_size, max_input_length, max_output_len</span>
<span class="bp">self</span><span class="o">.</span><span class="n">setup</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="n">input_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_input_length</span><span class="p">,</span>
<span class="n">max_new_tokens</span><span class="o">=</span><span class="n">max_new_tokens</span><span class="p">)</span>
<span class="n">output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">input_lengths</span><span class="p">,</span> <span class="n">sampling_config</span><span class="p">)</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="k">if</span> <span class="n">runtime_rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">output_ids</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:]</span>
<span class="k">return</span> <span class="n">outputs</span></div>
</div>
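<span class="c1"># Usage sketch (hypothetical setup; a serialized engine and tokenized</span>
<span class="c1"># inputs are assumed to exist already):</span>
<span class="c1">#   session = QWenForCausalLMGenerationSession(model_config, engine_buffer,</span>
<span class="c1">#                                              mapping)</span>
<span class="c1">#   outputs = session.generate(input_ids, input_lengths, sampling_config,</span>
<span class="c1">#                              max_new_tokens=128)</span>
<span class="c1">#   # on rank 0, outputs holds the beam-0 token ids of shape [batch, seq]</span>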
<div class="viewcode-block" id="MambaLMHeadModelGenerationSession">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.MambaLMHeadModelGenerationSession">[docs]</a>
<span class="k">class</span> <span class="nc">MambaLMHeadModelGenerationSession</span><span class="p">(</span><span class="n">GenerationSession</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span> <span class="n">ModelConfig</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span> <span class="o">=</span> <span class="n">model_config</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span> <span class="o">=</span> <span class="n">mapping</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span> <span class="o">=</span> <span class="n">_Runtime</span><span class="p">(</span><span class="n">engine_buffer</span><span class="p">,</span> <span class="n">mapping</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;cuda:</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">runtime_rank</span><span class="w"> </span><span class="o">%</span><span class="w"> </span><span class="n">mapping</span><span class="o">.</span><span class="n">gpus_per_node</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># dynamic_decoder currently use torch&#39;s current stream, so must let TRT enqueue use same stream here</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="o">=</span> <span class="n">debug_mode</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="o">=</span> <span class="n">debug_tensors_to_save</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span> <span class="o">=</span> <span class="n">cuda_graph_mode</span>
<span class="c1"># Optional inputs for dynamic decoder</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_decay</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_min</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_reset_ids</span> <span class="o">=</span> <span class="kc">None</span>
<span class="c1"># TODO: in tensorrt_llm/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp it&#39;s T, can be float or half?</span>
<span class="bp">self</span><span class="o">.</span><span class="n">embedding_bias_opt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="c1"># use one more block in paged kv cache.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">use_one_more_block</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span> <span class="o">=</span> <span class="n">pad_vocab_size</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;logits&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">]:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Logits dtype not supported by decoder. Falling back to float32. You may want to change the logits dtype to float16 in your model definition.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">classes</span><span class="o">.</span><span class="n">trtllm</span><span class="o">.</span><span class="n">DynamicDecodeOp</span><span class="p">(</span>
<span class="n">model_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gather_tree</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">gather_tree</span>
<span class="n">expected_tensor_names</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;input_ids&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;host_request_types&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;host_context_lengths&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;conv_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;ssm_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;slot_mapping&#39;</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;past_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;past_ssm_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;present_ssm_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">model_config</span><span class="o">.</span><span class="n">use_custom_all_reduce</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;all_reduce_workspace&#39;</span><span class="p">]</span>
<span class="n">found_tensor_names</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="ow">and</span> <span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">set</span><span class="p">(</span>
<span class="n">found_tensor_names</span><span class="p">):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;self.remove_input_padding=</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="si">}</span><span class="s1">, self.paged_state=</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;The following expected tensors are not found: </span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">)</span><span class="o">.</span><span class="n">difference</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">))</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Those tensors in engine are not expected: </span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">)</span><span class="o">.</span><span class="n">difference</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">))</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Expected tensor names: </span><span class="si">{</span><span class="n">expected_tensor_names</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Found tensor names: </span><span class="si">{</span><span class="n">found_tensor_names</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;Tensor names in engine are not the same as expected.&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
<span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">)</span> <span class="o">-</span> <span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">))</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">mamba_d_state</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">mamba_d_state</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">mamba_d_conv</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">mamba_d_conv</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">mamba_expand</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">mamba_expand</span>
<div class="viewcode-block" id="MambaLMHeadModelGenerationSession.setup">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.MambaLMHeadModelGenerationSession.setup">[docs]</a>
<span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">max_attention_window_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">sink_token_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_max_input_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_manager</span><span class="p">:</span> <span class="n">LoraManager</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_uids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">medusa_choices</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="c1"># Store these params related to buffer size to check against</span>
<span class="c1"># the input shape with the params given in decode()</span>
<span class="k">assert</span> <span class="n">beam_width</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;Only support beam width = 1 now.&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_context_length</span> <span class="o">=</span> <span class="n">max_context_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">=</span> <span class="n">max_new_tokens</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">=</span> <span class="n">max_context_length</span> <span class="o">+</span> <span class="n">max_new_tokens</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mamba_expand</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_width</span> <span class="o">=</span> <span class="n">beam_width</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="p">{}</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;logits&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mamba_conv1d_plugin</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_conv</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_conv</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ssm_state_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_state</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="c1"># we need two set of kv cache buffers, one for inputs, and the other for outputs.</span>
<span class="c1"># They will take turns to act as input and output buffers.</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">conv_state_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;1_present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">conv_state_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_ssm_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">ssm_state_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">conv_state_ptr</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
<span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">()],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
<span class="n">ssm_state_ptr</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
<span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_ssm_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">()],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;conv_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">conv_state_ptr</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;ssm_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">ssm_state_ptr</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span> <span class="o">=</span> <span class="kc">True</span></div>
<span class="k">def</span> <span class="nf">_get_context_shape_buffer</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">position_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">host_kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">hidden_states_input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">RuntimeTensor</span><span class="p">]:</span>
<span class="n">tensors</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">def</span> <span class="nf">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">add_tensor</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="n">name</span><span class="p">:</span> <span class="n">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">)})</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="s1">&#39;input_ids&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span> <span class="s1">&#39;logits&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="s1">&#39;last_token_ids&#39;</span><span class="p">)</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;ssm_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;ssm_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># conv state</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mamba_conv1d_plugin</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_conv</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_conv</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">conv_state</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">conv_state_shape</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">conv_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;past_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">present</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">present</span><span class="p">],</span> <span class="n">present</span><span class="p">)</span>
<span class="c1"># ssm state</span>
<span class="n">ssm_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_ssm_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">ssm_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;past_ssm_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">ssm_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;present_ssm_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="c1"># context request</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">&#39;host_request_types&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span> <span class="s1">&#39;host_context_lengths&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">slot_mapping</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">slot_mapping</span><span class="p">,</span> <span class="s1">&#39;slot_mapping&#39;</span><span class="p">)</span>
<span class="c1"># all reduce</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_custom_all_reduce</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span><span class="p">,</span> <span class="s1">&#39;all_reduce_workspace&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tensors</span>
<span class="k">def</span> <span class="nf">_get_next_step_shape_buffer</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">position_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">host_kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">hidden_states_input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="n">tensors</span> <span class="o">=</span> <span class="p">{}</span> <span class="c1"># Dict[str, RuntimeTensor]</span>
<span class="k">def</span> <span class="nf">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">add_tensor</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="n">name</span><span class="p">:</span> <span class="n">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">)})</span>
<span class="k">def</span> <span class="nf">add_tensor_with_shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
<span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">override_shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)})</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">input_ids_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">input_ids_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="s1">&#39;input_ids&#39;</span><span class="p">,</span> <span class="n">input_ids_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span> <span class="s1">&#39;logits&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="s1">&#39;last_token_ids&#39;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;ssm_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;ssm_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># conv state</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mamba_conv1d_plugin</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_conv</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_conv</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;1_present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;past_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="n">conv_state_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;past_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="n">conv_state_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;1_present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="c1"># ssm state</span>
<span class="n">ssm_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_ssm_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">ssm_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;past_ssm_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">ssm_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;present_ssm_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="c1"># generation requests</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">&#39;host_request_types&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span> <span class="s1">&#39;host_context_lengths&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">slot_mapping</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">slot_mapping</span><span class="p">,</span> <span class="s1">&#39;slot_mapping&#39;</span><span class="p">)</span>
<span class="c1"># all reduce</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_custom_all_reduce</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span><span class="p">,</span> <span class="s1">&#39;all_reduce_workspace&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tensors</span>
<span class="k">def</span> <span class="nf">_prepare_context_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">}</span>
<span class="k">return</span> <span class="n">ret</span>
<span class="k">def</span> <span class="nf">_prepare_generation_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">}</span>
<span class="k">return</span> <span class="n">ret</span>
<span class="k">def</span> <span class="nf">_capture_cuda_graph_and_instantiate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">,</span> <span class="n">stream</span><span class="p">,</span> <span class="n">step</span><span class="p">):</span>
<span class="n">instance_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span>
<span class="c1"># Mamba model I/O shape isn&#39;t changed in generation phase</span>
<span class="c1"># Create two cuda graph once.If cuda graph has already existed, skip it.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span>
<span class="c1"># WAR for TRT 9.x</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">preview_trt_version</span><span class="p">()</span> <span class="ow">and</span> <span class="n">step</span> <span class="o">&lt;</span> <span class="mi">3</span><span class="p">:</span>
<span class="k">return</span>
<span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_capture_cuda_graph_and_instantiate</span><span class="p">(</span>
<span class="n">context</span><span class="p">,</span> <span class="n">stream</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span></div>
</pre></div>
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>&#169; Copyright 2023, NVIDIA.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>