Source code for tensorrt_llm.runtime.generation
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import csv
import math
from dataclasses import dataclass, field
from functools import reduce, wraps
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Union

import numpy as np

# isort: off
import torch
import tensorrt as trt
# isort: on
from cuda import cudart

from tensorrt_llm.plugin.plugin import CustomAllReduceHelper

from .._ipc_utils import set_peer_access
from .._utils import (pad_vocab_size, preview_trt_version, str_dtype_to_torch,
                      torch_to_numpy, trt_dtype_to_torch)
from ..logger import logger
from ..lora_manager import LoraManager
from ..mapping import Mapping
from ..quantization import QuantMode
from .kv_cache_manager import GenerationSequence, KVCacheManager, KVCacheUpdater
from .session import _scoped_stream


def to_word_list_format(word_dict: List[List[str]],
                        tokenizer=None,
                        add_special_tokens=False):
    '''
    Format of word_dict:
        len(word_dict) should equal batch_size.
        word_dict[i] holds the words for batch i.
        len(word_dict[i]) must be 1, i.e. it contains exactly one string,
        which may pack several phrases separated by ",".
        For example, if word_dict[2] == " I am happy, I am sad", this function
        returns the ids of the two phrases " I am happy" and " I am sad".
    '''
    assert tokenizer is not None, "need to set tokenizer"

    flat_ids = []
    offsets = []
    for word_dict_item in word_dict:
        item_flat_ids = []
        item_offsets = []

        if isinstance(word_dict_item[0], bytes):
            word_dict_item = [word_dict_item[0].decode()]

        words = list(csv.reader(word_dict_item))[0]
        for word in words:
            ids = tokenizer.encode(word, add_special_tokens=add_special_tokens)

            if len(ids) == 0:
                continue

            item_flat_ids += ids
            item_offsets.append(len(ids))

        flat_ids.append(np.array(item_flat_ids))
        offsets.append(np.cumsum(np.array(item_offsets)))

    pad_to = max(1, max(len(ids) for ids in flat_ids))

    for i, (ids, offs) in enumerate(zip(flat_ids, offsets)):
        flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0)
        offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1)

    return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2))


def _prepare_input_ids(tensors: Sequence[torch.Tensor]):
    tensors = [torch.flatten(t) for t in tensors]
    data = torch.concat(tensors)
    row_lengths = [t.size(0) for t in tensors]
    row_lengths = torch.tensor(row_lengths,
                               dtype=torch.int32,
                               device=data.device)
    return (data, row_lengths)


def CUASSERT(cuda_ret):
    err = cuda_ret[0]
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(
            f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"
        )
    if len(cuda_ret) > 1:
        return cuda_ret[1:]
    return None


def _update_cuda_graph_instance(instance, graph):
    err = cudart.cudaGraphExecUpdate(instance, graph)
    if err != cudart.cudaError_t.cudaSuccess:
        # If updating the CUDA graph fails, destroy the executable graph and
        # instantiate a new one from the captured graph.
        CUASSERT(cudart.cudaGraphExecDestroy(instance))
        instance = CUASSERT(cudart.cudaGraphInstantiate(graph, 0))[0]
    return instance


def _prepare_attention_mask(input_ids: torch.Tensor,
                            pad_id: Optional[int] = None):
    is_pad_id_in_inputs = (pad_id is not None) and (pad_id in input_ids)
    if input_ids is not None and is_pad_id_in_inputs:
        mask = input_ids.ne(pad_id).int()
        # For enc-dec models, pad_id could be the start token and should always
        # be counted as a valid token rather than a padded one, so we force its
        # mask to be 1. This doesn't impact the existing behavior.
        mask[:, 0] = 1
        return mask
    else:
        return torch.ones(input_ids.shape,
                          dtype=torch.int32,
                          device=input_ids.device)


def _tile_beam_width(tensor: torch.Tensor, num_beams: int):
    new_shape = np.array(tensor.shape)
    new_shape[0] = new_shape[0] * num_beams

    tile_size = np.ones(new_shape.shape, dtype=np.int32)
    tile_size = np.insert(tile_size, 1, num_beams)

    new_tensor = torch.unsqueeze(tensor, 1)
    new_tensor = new_tensor.tile(tile_size.tolist())
    new_tensor = new_tensor.reshape(new_shape.tolist())
    return new_tensor


class _Runtime(object):
    runtime_rank: int
    runtime: trt.Runtime
    engine: trt.ICudaEngine
    ctx_context: trt.IExecutionContext
    context_0: trt.IExecutionContext
    context_1: trt.IExecutionContext
    cuda_graph_instances: List[cudart.cudaGraphExec_t]

    def __init__(self, engine_buffer, mapping: Mapping):
        self.address = None
        self.__prepare(mapping, engine_buffer)

    def _serialize_engine(self) -> trt.IHostMemory:
        return self.engine.serialize()

    def __create_and_setup_context(self, address, profile_idx,
                                   stream) -> trt.IExecutionContext:
        context = self.engine.create_execution_context_without_device_memory()
        assert context is not None
        context.device_memory = address
        context.set_optimization_profile_async(profile_idx, stream)
        return context

    def __prepare(self, mapping: Mapping, engine_buffer):
        self.runtime_rank = mapping.rank
        local_rank = self.runtime_rank % mapping.gpus_per_node
        torch.cuda.set_device(local_rank)
        CUASSERT(cudart.cudaSetDevice(local_rank))

        self.runtime = trt.Runtime(logger.trt_logger)
        self.engine = self.runtime.deserialize_cuda_engine(engine_buffer)
        assert self.engine is not None
        # device_memory_size is the memory required by the largest profile
        address = CUASSERT(cudart.cudaMalloc(self.engine.device_memory_size))[0]
        self.address = address

        # cuda graph ping-pong instances
        self.cuda_graph_instances = [None for _ in range(2)]

        with _scoped_stream() as stream:
            if self.engine.num_optimization_profiles == 1:
                # At step = 0, context_1 is active
                # At step = 1, context_0 is active
                # At step = 2, context_1 is active
                self.context_0 = self.__create_and_setup_context(
                    address, 0, stream)
                self.context_1 = self.__create_and_setup_context(
                    address, 0, stream)
                self.ctx_context = self.context_1
            elif self.engine.num_optimization_profiles == 2:
                # At step = 0, ctx_context is active
                # At step = 1, context_0 is active
                # At step = 2, context_1 is active
                self.ctx_context = self.__create_and_setup_context(
                    address, 0, stream)
                self.context_0 = self.__create_and_setup_context(
                    address, 1, stream)
                self.context_1 = self.__create_and_setup_context(
                    address, 1, stream)
            else:
                assert False, "At most two optimization profiles are supported"

    def _set_shape(self, context: trt.IExecutionContext,
                   shape_dict: Dict[str, List[int]]):
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            if name not in shape_dict:
                # shape and buffer can also be set by calling the _set_tensors API
                continue
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                ok = context.set_input_shape(name, shape_dict[name])
                dtype = self.engine.get_tensor_dtype(name)
                logger.debug(
                    f"setting input tensor {name} with shape {shape_dict[name]} and type {dtype}"
                )
                if not ok:
                    raise ValueError(
                        f"Couldn't assign {name} with shape {shape_dict[name]}, "
                        f"engine supports [min, opt, max] = {self.engine.get_tensor_profile_shape(name, self.engine.active_optimization_profile)}"
                    )

    def _set_buffer(self, context: trt.IExecutionContext,
                    buffer_dict: Dict[str, torch.Tensor]):
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            if name not in buffer_dict.keys():
                dtype = self.engine.get_tensor_dtype(name)
                shape = context.get_tensor_shape(name)
                buffer_dict[name] = torch.zeros(tuple(shape),
                                                dtype=trt_dtype_to_torch(dtype),
                                                device='cuda')
            assert buffer_dict[name].is_contiguous(
            ), f"{name} is not contiguous()"
            context.set_tensor_address(name, buffer_dict[name].data_ptr())

    def _set_tensors(self, context: trt.IExecutionContext,
                     tensors: Dict[str, "RuntimeTensor"]):
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            # It's allowed to call _set_tensors multiple times with different
            # tensors, each call setting only some of the engine tensors, so it
            # is valid to skip the ones not in the current tensors dict.
            if name not in tensors:
                if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
                    dtype = self.engine.get_tensor_dtype(name)
                    shape = context.get_tensor_shape(name)
                    tensors[name] = RuntimeTensor.from_torch(
                        name,
                        torch.zeros(tuple(shape),
                                    dtype=trt_dtype_to_torch(dtype),
                                    device='cuda'))
                else:
                    continue
            t = tensors[name]
            # Output shapes are inferred by TRT, so only input shapes are set here.
            if self.engine.get_tensor_mode(t.name) == trt.TensorIOMode.INPUT:
                context.set_input_shape(t.name, t.shape)
            context.set_tensor_address(t.name, t.data)

    def _check_tensors(self, context: trt.IExecutionContext) -> None:
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            ptr = context.get_tensor_address(name)
            if ptr == 0:
                raise RuntimeError(f"Engine I/O tensor {name} is unbound")

    def _run(self,
             context: trt.IExecutionContext,
             stream: Union[int, torch.cuda.Stream] = None) -> bool:
        if stream is None:
            stream = torch.cuda.current_stream().cuda_stream
        elif isinstance(stream, torch.cuda.Stream):
            stream = stream.cuda_stream
        ok = context.execute_async_v3(stream)
        return ok
|
||
|
||
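<span class="c1"># Usage sketch (illustrative only; assumes a `_Runtime` instance `runtime` and an</span>
<span class="c1"># execution `context` whose I/O tensors have already been bound):</span>
<span class="c1">#</span>
<span class="c1">#   runtime._check_tensors(context)   # raises if any I/O address is unbound</span>
<span class="c1">#   ok = runtime._run(context, torch.cuda.current_stream())</span>
<span class="c1">#   torch.cuda.synchronize()</span>
<span class="c1">#   if not ok:</span>
<span class="c1">#       raise RuntimeError("TRT engine execution failed")</span>
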
<span class="k">def</span> <span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">try</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">address</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaFree</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">)</span>
|
||
<span class="k">except</span> <span class="ne">TypeError</span><span class="p">:</span>
|
||
<span class="k">pass</span>
|
||
|
||
|
||
<div class="viewcode-block" id="ModelConfig">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.ModelConfig">[docs]</a>
<span class="nd">@dataclass</span>
<span class="k">class</span> <span class="nc">ModelConfig</span><span class="p">:</span>
<span class="n">max_batch_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">max_beam_width</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">vocab_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">num_kv_heads</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">gpt_attention_plugin</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">remove_input_padding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">model_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span>
<span class="n">paged_kv_cache</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">cross_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">head_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">has_position_embedding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">has_token_type_embedding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">tokens_per_block</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span>
<span class="n">max_prompt_embedding_table_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">quant_mode</span><span class="p">:</span> <span class="n">QuantMode</span> <span class="o">=</span> <span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">gather_context_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">gather_generation_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span>
<span class="n">use_custom_all_reduce</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">lora_plugin</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">lora_target_modules</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="nb">list</span><span class="p">)</span>
<span class="n">use_context_fmha_for_generation</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">trtllm_modules_to_hf_modules</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">skip_cross_qkv</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">num_medusa_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">max_medusa_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">mamba_d_state</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">mamba_d_conv</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">mamba_expand</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">paged_state</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">mamba_conv1d_plugin</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></div>


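<span class="c1"># Construction sketch (illustrative values for a hypothetical 12-layer decoder;</span>
<span class="c1"># every field not listed keeps its dataclass default):</span>
<span class="c1">#</span>
<span class="c1">#   config = ModelConfig(max_batch_size=8,</span>
<span class="c1">#                        max_beam_width=1,</span>
<span class="c1">#                        vocab_size=32000,</span>
<span class="c1">#                        num_layers=12,</span>
<span class="c1">#                        num_heads=16,</span>
<span class="c1">#                        num_kv_heads=16,</span>
<span class="c1">#                        hidden_size=2048,</span>
<span class="c1">#                        gpt_attention_plugin=True,</span>
<span class="c1">#                        dtype='float16')</span>
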
<div class="viewcode-block" id="SamplingConfig">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.SamplingConfig">[docs]</a>
<span class="nd">@dataclass</span>
<span class="k">class</span> <span class="nc">SamplingConfig</span><span class="p">:</span>
<span class="n">end_id</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">pad_id</span><span class="p">:</span> <span class="nb">int</span>

<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span>
<span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">max_attention_window_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">sink_token_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">stop_words_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">bad_words_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>

<span class="n">temperature</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="n">top_k</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">top_p</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="n">top_p_decay</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="c1"># float</span>
<span class="n">top_p_min</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="c1"># float</span>
<span class="n">top_p_reset_ids</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="c1"># int</span>

<span class="n">length_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="n">early_stopping</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">repetition_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="n">min_length</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">presence_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="n">frequency_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="n">use_beam_hyps</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>

<span class="c1"># None here means user didn't set it, and dynamicDecodeOp.cpp take optional value</span>
|
||
<span class="c1"># The real default value is set in dynamicDecodeOp.cpp when it's None</span>
|
||
<span class="n">beam_search_diversity_rate</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
|
||
<span class="n">random_seed</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
||
<span class="n">output_cum_log_probs</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
<span class="n">output_log_probs</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="SamplingConfig.update">
|
||
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.SamplingConfig.update">[docs]</a>
|
||
<span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||
<span class="n">unused_kwargs</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
|
||
<span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
||
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">):</span>
|
||
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">unused_kwargs</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
|
||
<span class="k">return</span> <span class="n">unused_kwargs</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
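<span class="c1"># Usage sketch (illustrative): greedy decoding with a 64-token budget. update()</span>
<span class="c1"># applies known fields and returns the kwargs it could not match:</span>
<span class="c1">#</span>
<span class="c1">#   sampling = SamplingConfig(end_id=2, pad_id=2, max_new_tokens=64)</span>
<span class="c1">#   leftover = sampling.update(temperature=0.8, not_a_field=123)</span>
<span class="c1">#   assert leftover == {'not_a_field': 123}</span>
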
<div class="viewcode-block" id="LogitsProcessor">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.LogitsProcessor">[docs]</a>
<span class="k">class</span> <span class="nc">LogitsProcessor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">"""</span>
<span class="sd"> Base class for all logit processors that can be applied during generation.</span>
<span class="sd"> """</span>

<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="si">}</span><span class="s2"> is an abstract class. Only classes inheriting this class can be called."</span>
<span class="p">)</span></div>


<div class="viewcode-block" id="LogitsProcessorList">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.LogitsProcessorList">[docs]</a>
<span class="k">class</span> <span class="nc">LogitsProcessorList</span><span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="n">LogitsProcessor</span><span class="p">):</span>

<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
<span class="k">for</span> <span class="n">processor</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">:</span>
<span class="n">scores</span> <span class="o">=</span> <span class="n">processor</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">scores</span><span class="p">)</span>
<span class="k">return</span> <span class="n">scores</span></div>


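<span class="c1"># Subclassing sketch (BanTokenProcessor is hypothetical): mask out one token id,</span>
<span class="c1"># then compose through LogitsProcessorList, which applies its members in order:</span>
<span class="c1">#</span>
<span class="c1">#   class BanTokenProcessor(LogitsProcessor):</span>
<span class="c1">#</span>
<span class="c1">#       def __init__(self, token_id: int):</span>
<span class="c1">#           self.token_id = token_id</span>
<span class="c1">#</span>
<span class="c1">#       def __call__(self, step, input_ids, scores):</span>
<span class="c1">#           scores[..., self.token_id] = float('-inf')</span>
<span class="c1">#           return scores</span>
<span class="c1">#</span>
<span class="c1">#   processors = LogitsProcessorList([BanTokenProcessor(13)])</span>
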
<div class="viewcode-block" id="StoppingCriteria">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.StoppingCriteria">[docs]</a>
<span class="k">class</span> <span class="nc">StoppingCriteria</span><span class="p">:</span>
<span class="w"> </span><span class="sd">"""</span>
<span class="sd"> Base class for all stopping criteria that can be applied during generation.</span>
<span class="sd"> """</span>

<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">"StoppingCriteria needs to be subclassed"</span><span class="p">)</span></div>


<div class="viewcode-block" id="StoppingCriteriaList">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.StoppingCriteriaList">[docs]</a>
<span class="k">class</span> <span class="nc">StoppingCriteriaList</span><span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="n">StoppingCriteria</span><span class="p">):</span>

<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">any</span><span class="p">(</span><span class="n">criteria</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">scores</span><span class="p">)</span> <span class="k">for</span> <span class="n">criteria</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">)</span></div>


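<span class="c1"># Subclassing sketch (MaxStepsCriteria is hypothetical): StoppingCriteriaList</span>
<span class="c1"># stops generation as soon as any member returns True:</span>
<span class="c1">#</span>
<span class="c1">#   class MaxStepsCriteria(StoppingCriteria):</span>
<span class="c1">#</span>
<span class="c1">#       def __init__(self, max_steps: int):</span>
<span class="c1">#           self.max_steps = max_steps</span>
<span class="c1">#</span>
<span class="c1">#       def __call__(self, step, input_ids, scores):</span>
<span class="c1">#           return step >= self.max_steps</span>
<span class="c1">#</span>
<span class="c1">#   stopping = StoppingCriteriaList([MaxStepsCriteria(32)])</span>
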
<span class="k">class</span> <span class="nc">RuntimeTensor</span><span class="p">:</span>

<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="s2">""</span>
<span class="c1"># shape is the one sent to TRT, the actual torch tensor can be larger than the shape</span>
|
||
<span class="c1"># this is useful when allocating a big KV cache tensor at the beginning and incremental seq length dim of TRT engine's input tensor</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_torch_tensor</span> <span class="o">=</span> <span class="kc">None</span>
|
||
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">from_torch</span><span class="p">(</span>
|
||
<span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
|
||
<span class="n">data</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">override_shape</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Iterable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="s1">'RuntimeTensor'</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">))</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">RuntimeTensor</span><span class="p">()</span>
|
||
<span class="n">t</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="n">name</span>
|
||
<span class="c1"># need to hold the torch tensor for memory life time</span>
|
||
<span class="n">t</span><span class="o">.</span><span class="n">_torch_tensor</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
|
||
<span class="n">torch_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">size</span><span class="p">())</span>
|
||
<span class="k">if</span> <span class="n">override_shape</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">t</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">override_shape</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">override_shape</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span>
|
||
<span class="n">override_shape</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="nb">all</span><span class="p">([</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span> <span class="o">>=</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">override_shape</span>
|
||
<span class="p">]),</span> <span class="sa">f</span><span class="s2">"Expect all dimensions >=0, got </span><span class="si">{</span><span class="n">override_shape</span><span class="si">}</span><span class="s2">"</span>
|
||

<span class="k">def</span> <span class="nf">volume_func</span><span class="p">(</span><span class="n">dims</span><span class="p">):</span>
<span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">volume_func</span><span class="p">(</span><span class="n">override_shape</span><span class="p">)</span> <span class="o"><=</span> <span class="n">volume_func</span><span class="p">(</span><span class="n">torch_shape</span><span class="p">),</span> \
|
||
<span class="sa">f</span><span class="s2">"Override the shape to be larger than the underlying torch Tensor, got </span><span class="si">{</span><span class="n">override_shape</span><span class="si">}</span><span class="s2">, torch tensor shape </span><span class="si">{</span><span class="n">torch_shape</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">t</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">torch_shape</span>
|
||
<span class="k">return</span> <span class="n">t</span>
|
||
|
||
<span class="k">def</span> <span class="nf">to_torch</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_torch_tensor</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_shape</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span> <span class="nf">data</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_torch_tensor</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">()</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span> <span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_name</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span> <span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_torch_tensor</span><span class="o">.</span><span class="n">dtype</span>
|
||
|
||
|
||
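<span class="c1"># Usage sketch (illustrative): wrap a preallocated buffer while advertising a</span>
<span class="c1"># smaller shape to TRT, as done for incrementally growing KV-cache inputs:</span>
<span class="c1">#</span>
<span class="c1">#   buf = torch.empty((8, 1024), dtype=torch.float16, device='cuda')</span>
<span class="c1">#   rt = RuntimeTensor.from_torch('past_key_value_0', buf,</span>
<span class="c1">#                                 override_shape=[8, 128])</span>
<span class="c1">#   assert rt.shape == [8, 128] and rt.data == buf.data_ptr()</span>
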
<div class="viewcode-block" id="GenerationSession">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession">[docs]</a>
<span class="k">class</span> <span class="nc">GenerationSession</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>

<span class="n">_model_config</span><span class="p">:</span> <span class="n">ModelConfig</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span>
<span class="n">runtime</span><span class="p">:</span> <span class="n">_Runtime</span>
<span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">buffer_allocated</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">debug_mode</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">quant_mode</span><span class="p">:</span> <span class="n">QuantMode</span>
<span class="n">cuda_graph_mode</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">dtype</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span>
<span class="n">debug_tensors_to_save</span><span class="p">:</span> <span class="kc">None</span>
<span class="n">num_medusa_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">medusa_topks</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">medusa_paths</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">medusa_tree_ids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">medusa_position_offsets</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">medusa_temperature</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span>

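<span class="c1"># Construction sketch (illustrative; `engine_buffer` is assumed to be a serialized</span>
<span class="c1"># TRT engine, and a default Mapping() describes a single-GPU run):</span>
<span class="c1">#</span>
<span class="c1">#   session = GenerationSession(config, engine_buffer, Mapping())</span>
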
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
||
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
|
||
<span class="n">engine_buffer</span><span class="p">,</span>
|
||
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
|
||
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span> <span class="n">ModelConfig</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span> <span class="o">=</span> <span class="n">model_config</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span> <span class="o">=</span> <span class="n">mapping</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span> <span class="o">=</span> <span class="n">_Runtime</span><span class="p">(</span><span class="n">engine_buffer</span><span class="p">,</span> <span class="n">mapping</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s1">'cuda:</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">runtime_rank</span><span class="w"> </span><span class="o">%</span><span class="w"> </span><span class="n">mapping</span><span class="o">.</span><span class="n">gpus_per_node</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
||
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||
<span class="c1"># dynamic_decoder currently use torch's current stream, so must let TRT enqueue use same stream here</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="o">=</span> <span class="n">debug_mode</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="o">=</span> <span class="n">debug_tensors_to_save</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span> <span class="o">=</span> <span class="n">cuda_graph_mode</span>
|
||
<span class="c1"># Optional inputs for dynamic decoder</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_decay</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_min</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_reset_ids</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="c1"># TODO: in tensorrt_llm/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp it's T, can be float or half?</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">embedding_bias_opt</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="c1"># use one more block in paged kv cache.</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">use_one_more_block</span> <span class="o">=</span> <span class="kc">False</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span> <span class="o">=</span> <span class="kc">False</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span> <span class="o">=</span> <span class="n">pad_vocab_size</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="s2">"The paged KV cache in Python runtime is experimental. For performance and correctness, please, use C++ runtime."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">classes</span><span class="o">.</span><span class="n">trtllm</span><span class="o">.</span><span class="n">NcclCommunicatorOp</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'logits'</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">]:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="s2">"Logits dtype not supported by decoder. Falling back to float32. You may want to change the logits dtype to float16 in your model definition."</span>
|
||
<span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">classes</span><span class="o">.</span><span class="n">trtllm</span><span class="o">.</span><span class="n">DynamicDecodeOp</span><span class="p">(</span>
|
||
<span class="n">model_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_beam_width</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">use_context_fmha_for_generation</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="s2">"Context FMHA is used for generation. Use it only for testing"</span><span class="p">)</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">gather_tree</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">gather_tree</span>
|
||
|
||
<span class="n">expected_tensor_names</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'input_ids'</span><span class="p">]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'hidden_states_input'</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">model_config</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'last_token_ids'</span><span class="p">]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'hidden_states_output'</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">has_position_embedding</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">(</span>
|
||
<span class="p">):</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'position_ids'</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">has_token_type_embedding</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">(</span>
|
||
<span class="p">):</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'token_type_ids'</span><span class="p">]</span>
|
||
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'cache_indirection'</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'kv_cache_block_pointers'</span><span class="p">]</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'host_kv_cache_block_pointers'</span><span class="p">]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="sa">f</span><span class="s1">'past_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span><span class="p">:</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="s1">'sequence_length'</span><span class="p">,</span> <span class="s1">'context_lengths'</span><span class="p">,</span> <span class="s1">'host_request_types'</span><span class="p">,</span>
|
||
<span class="s1">'host_past_key_value_lengths'</span><span class="p">,</span> <span class="s1">'host_sink_token_length'</span>
|
||
<span class="p">]</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'host_max_attention_window_sizes'</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">expected_tensor_names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s1">'host_context_lengths'</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="s1">'attention_mask'</span><span class="p">,</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="s1">'prompt_embedding_table'</span><span class="p">,</span> <span class="s1">'tasks'</span><span class="p">,</span> <span class="s1">'prompt_vocab_size'</span>
|
||
<span class="p">]</span>

        if model_config.cross_attention:
            if model_config.gpt_attention_plugin:
                expected_tensor_names += [
                    f'cross_present_key_value_{i}'
                    for i in range(self.first_layer, self.last_layer)
                ]
                expected_tensor_names += [
                    f'cross_past_key_value_{i}'
                    for i in range(self.first_layer, self.last_layer)
                ]
            else:
                expected_tensor_names += [
                    'cross_attention_mask',
                ]

            expected_tensor_names += [
                'encoder_output',
                'encoder_input_lengths',
                'encoder_max_input_length',
                'cross_kv_cache_gen',
            ]
            self.skip_cross_qkv = model_config.skip_cross_qkv
            if self.skip_cross_qkv:
                expected_tensor_names += ['cross_qkv_reuse']

        if self.mapping.tp_size > 1 and model_config.use_custom_all_reduce:
            expected_tensor_names += ['all_reduce_workspace']

        self.lora_target_modules = model_config.lora_target_modules

        # In the current design, q_lora_params, k_lora_params and
        # v_lora_params should all be enabled or all be disabled at the same
        # time. However, some LoRA modules contain only one or two of them,
        # so zero tensors are used to fill in the missing ones.
        self.missing_qkv_modules = []
        if self.lora_target_modules is not None:
            if any(x in self.lora_target_modules
                   for x in ["attn_q", "attn_k", "attn_v"]):
                for lora_module in ["attn_q", "attn_k", "attn_v"]:
                    if lora_module not in self.lora_target_modules:
                        self.missing_qkv_modules.append(lora_module)
            if any(x in self.lora_target_modules
                   for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
                for lora_module in [
                        "cross_attn_q", "cross_attn_k", "cross_attn_v"
                ]:
                    if lora_module not in self.lora_target_modules:
                        self.missing_qkv_modules.append(lora_module)

        if model_config.lora_plugin:

            for lora_module in (self.lora_target_modules +
                                self.missing_qkv_modules):
                expected_tensor_names += [
                    f'{lora_module}_lora_ranks_{i}'
                    for i in range(self.first_layer, self.last_layer)
                ]

                expected_tensor_names += [
                    f'{lora_module}_lora_weights_pointers_{i}'
                    for i in range(self.first_layer, self.last_layer)
                ]
            if self.cross_attention and self.remove_input_padding:
                expected_tensor_names += ['host_encoder_input_lengths']

        if model_config.num_medusa_heads > 0:
            expected_tensor_names += [
                'medusa_position_offsets', 'medusa_packed_mask', 'medusa_logits'
            ]

        found_tensor_names = [
            self.runtime.engine.get_tensor_name(i)
            for i in range(self.runtime.engine.num_io_tensors)
        ]
        if not self.debug_mode and set(expected_tensor_names) != set(
                found_tensor_names):
            logger.error(
                f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}"
            )
            logger.error(
                f"These tensors in the engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}"
            )
            logger.error(f"Expected tensor names: {expected_tensor_names}")
            logger.error(f"Found tensor names: {found_tensor_names}")
            raise RuntimeError(
                "Tensor names in the engine are not the same as expected. To use this GenerationSession, "
                "you need to use PretrainedModel.prepare_inputs to create the TRT network inputs."
            )
        if self.debug_mode:
            self.debug_tensors = list(
                set(found_tensor_names) - set(expected_tensor_names))

    @property
    def vocab_size(self):
        return self._model_config.vocab_size

    @property
    def num_layers(self):
        assert self._model_config.num_layers % self.mapping.pp_size == 0, \
            f"num_layers {self._model_config.num_layers} must be a multiple of pipeline parallelism size {self.mapping.pp_size}"
        return self._model_config.num_layers // self.mapping.pp_size

    @property
    def first_layer(self):
        return self.num_layers * self.mapping.pp_rank

    @property
    def last_layer(self):
        return self.first_layer + self.num_layers

    @property
    def num_heads(self):
        return self._model_config.num_heads

    @property
    def hidden_size(self):
        return self._model_config.hidden_size

    @property
    def use_gpt_attention_plugin(self):
        return self._model_config.gpt_attention_plugin

    @property
    def use_mamba_conv1d_plugin(self):
        return self._model_config.mamba_conv1d_plugin

    @property
    def paged_kv_cache(self):
        return self._model_config.paged_kv_cache

    @property
    def tokens_per_block(self):
        return self._model_config.tokens_per_block

    @property
    def remove_input_padding(self):
        return self._model_config.remove_input_padding

    @property
    def num_heads_kv(self):
        return self._model_config.num_kv_heads

    @property
    def head_size(self):
        return self.hidden_size // self.num_heads if self._model_config.head_size is None else self._model_config.head_size

    @property
    def max_prompt_embedding_table_size(self):
        return self._model_config.max_prompt_embedding_table_size

    @property
    def quant_mode(self):
        return self._model_config.quant_mode

    @property
    def gather_context_logits(self):
        return self._model_config.gather_context_logits

    @property
    def gather_generation_logits(self):
        return self._model_config.gather_generation_logits

    @property
    def dtype(self):
        return str_dtype_to_torch(self._model_config.dtype)

    @property
    def use_custom_all_reduce(self):
        return self._model_config.use_custom_all_reduce

    def cuda_stream_guard(func):
        """Sync an external stream and set the current stream to the one
        bound to the session. Reset on exit.
        """

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            external_stream = torch.cuda.current_stream()
            if external_stream != self.stream:
                external_stream.synchronize()
                torch.cuda.set_stream(self.stream)
            ret = func(self, *args, **kwargs)
            if external_stream != self.stream:
                self.stream.synchronize()
                torch.cuda.set_stream(external_stream)
            return ret

        return wrapper

    @property
    def cross_attention(self):
        return self._model_config.cross_attention

    @property
    def has_position_embedding(self):
        return self._model_config.has_position_embedding

    @property
    def has_token_type_embedding(self):
        return self._model_config.has_token_type_embedding

    @property
    def use_lora_plugin(self):
        return self._model_config.lora_plugin

    @property
    def use_context_fmha_for_generation(self):
        return self._model_config.use_context_fmha_for_generation

    @property
    def is_medusa_mode(self):
        return self.num_medusa_heads > 0

    @property
    def max_medusa_tokens(self):
        return self._model_config.max_medusa_tokens

    @property
    def num_medusa_heads(self):
        return self._model_config.num_medusa_heads

    def _capture_cuda_graph_and_instantiate(self, context, stream, step):
        instance_idx = (step + 1) % 2
        # Capture the CUDA graph for the next step.
        CUASSERT(
            cudart.cudaStreamBeginCapture(
                stream,
                cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal))
        context.execute_async_v3(stream)
        next_graph = CUASSERT(cudart.cudaStreamEndCapture(stream))[0]

        if self.runtime.cuda_graph_instances[instance_idx] is not None:
            self.runtime.cuda_graph_instances[
                instance_idx] = _update_cuda_graph_instance(
                    self.runtime.cuda_graph_instances[instance_idx], next_graph)
        else:
            self.runtime.cuda_graph_instances[instance_idx] = CUASSERT(
                cudart.cudaGraphInstantiate(next_graph, 0))[0]

        # Pre-upload the CUDA graph to the stream.
        CUASSERT(
            cudart.cudaGraphUpload(
                self.runtime.cuda_graph_instances[instance_idx], stream))

    @property
    def paged_state(self):
        return self._model_config.paged_state

    def __setup_decoder(self, input_ids: torch.Tensor,
                        sampling_config: SamplingConfig,
                        host_context_lengths: torch.Tensor):
        '''Allocate buffers and set up the post-processing decoder kernel.
        '''
        batch_size = host_context_lengths.shape[0]
        scfg = sampling_config  # just a shorter name, no other meaning
        if isinstance(scfg.top_k, torch.Tensor):
            assert scfg.top_k.dtype == torch.int32, f"scfg.top_k.dtype ({scfg.top_k.dtype}) must be torch.int32"
            assert scfg.top_k.shape[
                0] == batch_size, f"scfg.top_k.shape[0] ({scfg.top_k.shape[0]}) must equal batch_size ({batch_size})"
            self.top_k = scfg.top_k
        else:
            self.top_k = torch.full([batch_size], scfg.top_k, dtype=torch.int32)

        if isinstance(scfg.top_p, torch.Tensor):
            assert scfg.top_p.dtype == torch.float32, f"scfg.top_p.dtype ({scfg.top_p.dtype}) must be torch.float32"
            assert scfg.top_p.shape[
                0] == batch_size, f"scfg.top_p.shape[0] ({scfg.top_p.shape[0]}) must equal batch_size ({batch_size})"
            self.top_p = scfg.top_p
        else:
            self.top_p = torch.full([batch_size],
                                    scfg.top_p,
                                    dtype=torch.float32)

        if isinstance(scfg.temperature, torch.Tensor):
            assert scfg.temperature.dtype == torch.float32, f"scfg.temperature.dtype ({scfg.temperature.dtype}) must be torch.float32"
            assert scfg.temperature.shape[
                0] == batch_size, f"scfg.temperature.shape[0] ({scfg.temperature.shape[0]}) must equal batch_size ({batch_size})"
            self.temperature = scfg.temperature
        else:
            self.temperature = torch.full([batch_size],
                                          scfg.temperature,
                                          dtype=torch.float32)

        if isinstance(scfg.repetition_penalty, torch.Tensor):
            assert scfg.repetition_penalty.dtype == torch.float32, f"scfg.repetition_penalty.dtype ({scfg.repetition_penalty.dtype}) must be torch.float32"
            assert scfg.repetition_penalty.shape[
                0] == batch_size, f"scfg.repetition_penalty.shape[0] ({scfg.repetition_penalty.shape[0]}) must equal batch_size ({batch_size})"
            self.repetition_penalty = scfg.repetition_penalty
        elif scfg.repetition_penalty == 1.0:
            self.repetition_penalty = None
        else:
            self.repetition_penalty = torch.full([batch_size],
                                                 scfg.repetition_penalty,
                                                 dtype=torch.float32)
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.length_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32"</span>
|
||
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
||
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.length_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
||
<span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">length_penalty</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.early_stopping.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32"</span>
|
||
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
||
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.early_stopping.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">host_early_stopping</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">host_early_stopping</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
||
<span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">early_stopping</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">host_early_stopping</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.presence_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32"</span>
|
||
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
||
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.presence_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span>
|
||
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">==</span> <span class="mf">0.0</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
||
<span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.frequency_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32"</span>
|
||
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
||
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.frequency_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span>
|
||
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">==</span> <span class="mf">0.0</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
||
<span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.min_length.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32"</span>
|
||
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
||
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.min_length.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min_length</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.beam_search_diversity_rate.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32"</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.beam_search_diversity_rate.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.random_seed.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int64"</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.random_seed.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span><span class="o">.</span><span class="n">setup</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_p</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">host_early_stopping</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_decay</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_p_min</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_p_reset_ids</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"end_id cannot be none"</span>
|
||
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">'pad_id cannot be none'</span>
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">max</span><span class="p">()</span>
<span class="c1"># setup output ids buffer</span>
<span class="k">if</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># input_ids only have one dimension, which means remove_padding is enabled</span>
<span class="n">split_ids_list</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">input_ids</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">host_context_lengths</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
<span class="n">padded_input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">to_padded_tensor</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">nested_tensor</span><span class="p">(</span><span class="n">split_ids_list</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">padded_input_ids</span> <span class="o">=</span> <span class="n">input_ids</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="n">tiled_input_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">padded_input_ids</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">)</span>
<span class="n">tiled_input_ids</span> <span class="o">=</span> <span class="n">tiled_input_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">)</span>
<span class="n">tiled_input_ids</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># TODO: delete?</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">(</span><span class="n">tiled_input_ids</span><span class="p">,</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">-</span> <span class="n">max_context_length</span><span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)),</span>
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">(</span><span class="n">padded_input_ids</span><span class="p">,</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">-</span> <span class="n">max_context_length</span><span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)),</span>
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># Note: we still allocate max_seq_length size of parent ids (not max_attention_window_size).</span>
<span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_temperature</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_logits</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'logits'</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">></span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_cum_log_probs</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
<span class="o">-</span><span class="mf">1e20</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_log_probs</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs_tiled</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs_tiled</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">finished</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_tgt</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
<span class="n">size</span><span class="o">=</span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">],</span>
<span class="n">fill_value</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_sequence_lengths_tgt</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_tgt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_sequence_lengths_tgt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cross_qkv_reuse</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">def</span> <span class="nf">_tensor_dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="c1"># return torch dtype given tensor name for convenience</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
<span class="k">return</span> <span class="n">dtype</span>
<span class="k">def</span> <span class="nf">_init_medusa</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">medusa_choices</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]):</span>
<span class="kn">from</span> <span class="nn">tensorrt_llm.runtime.medusa_utils</span> <span class="kn">import</span> <span class="p">(</span><span class="n">_medusa_setup</span><span class="p">,</span>
<span class="n">expand_choices_if_needed</span><span class="p">)</span>
<span class="n">medusa_choices</span> <span class="o">=</span> <span class="n">expand_choices_if_needed</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">)</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">></span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o"><=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_medusa_tokens</span>
<span class="n">medusa_info</span> <span class="o">=</span> <span class="n">_medusa_setup</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_topks</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_topks</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_mask</span><span class="p">[</span><span class="mi">1</span><span class="p">:,</span> <span class="mi">1</span><span class="p">:]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">bool</span>
<span class="p">)</span> <span class="c1"># convert to bool, original mask includes true token as well</span>
<span class="c1"># Expand medusa position offsets to number of batch size in order to be compatible with the new Medusa.</span>
<span class="n">target_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_packed_mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">target_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_packed_mask</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_packed_mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span>
<span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">target_shape</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_paths</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_tree_ids</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_tree_ids</span>
<span class="c1"># Expand medusa position offsets to number of batch size in order to be compatible with the new Medusa.</span>
<span class="n">target_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
<span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_position_offsets</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">target_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_position_offsets</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_position_offsets</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span>
<span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">target_shape</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">medusa_fp_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">medusa_fp_mask</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span><span class="p">)]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s1">'-inf'</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span> <span class="o">=</span> <span class="n">medusa_fp_mask</span>
<span class="k">return</span>
<div class="viewcode-block" id="GenerationSession.setup">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.setup">[docs]</a>
<span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">max_attention_window_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">sink_token_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_max_input_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_manager</span><span class="p">:</span> <span class="n">LoraManager</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_uids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">medusa_choices</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="c1"># Store these params related to buffer size to check against</span>
|
||
<span class="c1"># the input shape with the params given in decode()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_context_length</span> <span class="o">=</span> <span class="n">max_context_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">=</span> <span class="n">max_new_tokens</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">=</span> <span class="n">max_context_length</span> <span class="o">+</span> <span class="n">max_new_tokens</span>
<span class="k">if</span> <span class="n">medusa_choices</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">max_medusa_tokens</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_width</span> <span class="o">=</span> <span class="n">beam_width</span>
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span> <span class="o">=</span> <span class="n">encoder_max_input_length</span>
<span class="k">if</span> <span class="n">max_attention_window_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="s2">"The max_attention_window_size is not set, we will use max_seq_length by default."</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="k">if</span> <span class="n">max_attention_window_size</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">"The value of max_attention_window_size should ideally not exceed max_seq_length. "</span>
<span class="s2">"Therefore, it has been adjusted to match the value of max_seq_length."</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">"The value of max_attention_window_size should ideally not exceed max_seq_length. "</span>
<span class="s2">"Therefore, it has been adjusted to match the value of max_seq_length."</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
<span class="k">if</span> <span class="n">max_attention_window_size</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="s2">"max_attention_window_size tensor's size is not equal to num_layers! "</span>
<span class="s2">"Note that num_layers = num_total_layers // pipeline_parallelism_size."</span>
<span class="p">)</span>
<span class="k">assert</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span>
<span class="n">max_attention_window_size</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">"invalid max_attention_window_size!"</span>
<span class="k">if</span> <span class="n">sink_token_length</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sink_token_length</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">=</span> <span class="n">sink_token_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
<span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">"invalid sink_token_length!"</span>
<span class="bp">self</span><span class="o">.</span><span class="n">use_one_more_block</span> <span class="o">=</span> <span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="n">beam_width</span> <span class="o">></span> <span class="mi">1</span>
|
||
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span> <span class="o">=</span> <span class="n">lora_manager</span>
|
||
<span class="k">if</span> <span class="n">medusa_choices</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_init_medusa</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">)</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="p">{}</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
||
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
|
||
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'logits'</span><span class="p">),</span>
|
||
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||
<span class="n">medusa_logits_shape</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
|
||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">medusa_logits_shape</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">*</span>
|
||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'medusa_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
||
<span class="n">medusa_logits_shape</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
|
||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'medusa_logits'</span><span class="p">),</span>
|
||
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
||
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
|
||
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'logits'</span><span class="p">),</span>
|
||
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
            <span class="c1"># use shape info to pass max length info in remove padding mode</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'encoder_max_input_length'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
                <span class="p">(</span><span class="n">encoder_max_input_length</span><span class="p">,</span> <span class="p">),</span>
                <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'encoder_max_input_length'</span><span class="p">),</span>
                <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
            <span class="n">bubble_len</span> <span class="o">=</span> <span class="mi">0</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">bubble_len</span> <span class="o">+=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span> <span class="o">-</span>
                               <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">)</span>
            <span class="n">blocks</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span>
                <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">+</span> <span class="n">bubble_len</span><span class="p">)</span> <span class="o">/</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">)</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_one_more_block</span><span class="p">:</span>
                <span class="n">blocks</span> <span class="o">+=</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span>
            <span class="n">cache_shape</span> <span class="o">=</span> <span class="p">(</span>
                <span class="n">blocks</span><span class="p">,</span>
                <span class="mi">2</span><span class="p">,</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
            <span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">cache_shape</span> <span class="o">=</span> <span class="p">(</span>
                <span class="n">batch_size</span><span class="p">,</span>
                <span class="mi">2</span><span class="p">,</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
            <span class="p">)</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
                <span class="n">cross_cache_shape</span> <span class="o">=</span> <span class="p">(</span>
                    <span class="n">batch_size</span><span class="p">,</span>
                    <span class="mi">2</span><span class="p">,</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span><span class="p">,</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
                <span class="p">)</span>

        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_kv_cache_quant</span><span class="p">():</span>
                <span class="c1"># Since torch does not support fp8 yet, use int8 here.</span>
                <span class="n">kv_cache_type</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">int8</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">kv_cache_type</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span>
                    <span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
                <span class="n">cache_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'cross_present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
                    <span class="n">cross_cache_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
                                                     <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
                                                     <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="c1"># Without the plugin, we need two sets of KV cache buffers:</span>
            <span class="c1"># one for inputs, and the other for outputs.</span>
            <span class="c1"># They take turns acting as the input and output buffers.</span>
            <span class="c1"># Not applicable to cross KV buffers, as those are constant.</span>
            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
                <span class="n">trt_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span>
                    <span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
                <span class="k">if</span> <span class="n">trt_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span><span class="p">:</span>
                    <span class="c1"># PyTorch doesn't support the fp8 datatype, so use int8 instead; the two datatypes have the same size.</span>
                    <span class="c1"># TODO: Remove this section once PyTorch supports the fp8 datatype.</span>
                    <span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">int8</span>
                <span class="k">else</span><span class="p">:</span>
                    <span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'1_present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
                    <span class="n">cache_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_custom_all_reduce</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">set_peer_access</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">ipc_buffers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span> <span class="o">=</span> <span class="n">CustomAllReduceHelper</span><span class="o">.</span><span class="n">allocate_workspace</span><span class="p">(</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="p">,</span>
                <span class="n">CustomAllReduceHelper</span><span class="o">.</span><span class="n">max_workspace_size_auto</span><span class="p">(</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">))</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="k">assert</span> <span class="n">lora_uids</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
            <span class="n">lora_weights_pointers_list</span> <span class="o">=</span> <span class="p">[</span>
                <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span>
                            <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
                <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">)</span>
            <span class="p">]</span>

            <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
                <span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>

                <span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">+</span>
                                    <span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="p">):</span>
                    <span class="n">lora_ranks_</span> <span class="o">=</span> <span class="p">[]</span>
                    <span class="n">lora_ptrs_</span> <span class="o">=</span> <span class="p">[]</span>
                    <span class="k">for</span> <span class="n">batch_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
                        <span class="n">lora_uid</span> <span class="o">=</span> <span class="n">lora_uids</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">]</span>
                        <span class="k">if</span> <span class="n">lora_uid</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">lora_uid</span> <span class="o">!=</span> <span class="s2">"-1"</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span><span class="o">.</span><span class="n">uid_to_low_ranks</span><span class="p">(</span>
                                <span class="n">lora_uid</span><span class="p">)[</span><span class="n">layer_idx</span><span class="p">][</span><span class="n">lora_module</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
                            <span class="n">lora_ranks_</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
                                <span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span><span class="o">.</span><span class="n">uid_to_low_ranks</span><span class="p">(</span><span class="n">lora_uid</span><span class="p">)</span>
                                <span class="p">[</span><span class="n">layer_idx</span><span class="p">][</span><span class="n">lora_module</span><span class="p">])</span>
                            <span class="n">lora_ptrs_</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
                                <span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span><span class="o">.</span><span class="n">lora_weights_pointers_list</span><span class="p">[</span>
                                    <span class="n">layer_idx</span><span class="p">][</span><span class="n">lora_uid</span><span class="p">][</span><span class="n">lora_module</span><span class="p">])</span>
                        <span class="k">else</span><span class="p">:</span>
                            <span class="n">lora_ranks_</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
                            <span class="n">lora_ptrs_</span><span class="o">.</span><span class="n">append</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span>

                    <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
                        <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_ranks_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span><span class="p">:</span>
                        <span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">(</span><span class="n">lora_ranks_</span><span class="p">)</span>
                    <span class="p">})</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
                        <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span><span class="p">:</span>
                        <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">lora_ptrs_</span><span class="p">)</span>
                    <span class="p">})</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'medusa_packed_mask'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_packed_mask</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
                <span class="s1">'medusa_position_offsets'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_position_offsets</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span> <span class="o">=</span> <span class="kc">True</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span></div>


    <span class="k">def</span> <span class="nf">_get_context_shape_buffer</span><span class="p">(</span>
            <span class="bp">self</span><span class="p">,</span>
            <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
            <span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
            <span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
            <span class="n">position_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
            <span class="n">last_token_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
            <span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
            <span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
            <span class="n">cache_indirection</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
            <span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
            <span class="n">host_kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
            <span class="n">hidden_states_input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
            <span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
            <span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
            <span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
            <span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
            <span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">RuntimeTensor</span><span class="p">]:</span>
        <span class="n">tensors</span> <span class="o">=</span> <span class="p">{}</span>

        <span class="k">def</span> <span class="nf">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
            <span class="k">return</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>

        <span class="k">def</span> <span class="nf">add_tensor</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
            <span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="n">name</span><span class="p">:</span> <span class="n">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">)})</span>

        <span class="k">def</span> <span class="nf">add_tensor_with_shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
            <span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
                <span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">override_shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)})</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="s1">'context_lengths'</span><span class="p">)</span>
        <span class="n">add_tensor</span><span class="p">(</span><span class="n">cache_indirection</span><span class="p">,</span> <span class="s1">'cache_indirection'</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="s1">'position_ids'</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
            <span class="c1"># In the context phase, the cross KV cache needs to be generated, so set this to True.</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
                       <span class="s1">'cross_kv_cache_gen'</span><span class="p">)</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_qkv</span><span class="p">:</span>
                <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_qkv_reuse</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
                    <span class="c1"># see Attention's self.qkv output dim</span>
                    <span class="n">cross_qkv_out_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span> <span class="o">+</span> <span class="p">(</span>
                        <span class="mi">2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
                    <span class="n">cross_qkv_shape</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span>
                        <span class="n">cross_qkv_out_dim</span><span class="p">,</span> <span class="p">)</span>
                    <span class="n">cross_qkv_reuse</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">cross_qkv_shape</span><span class="p">,</span>
                                                  <span class="n">dtype</span><span class="o">=</span><span class="n">encoder_output</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
                                                  <span class="n">device</span><span class="o">=</span><span class="n">encoder_output</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">cross_qkv_reuse</span> <span class="o">=</span> <span class="n">cross_qkv_reuse</span>
                <span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cross_qkv_reuse</span><span class="p">,</span> <span class="s1">'cross_qkv_reuse'</span><span class="p">)</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_output</span><span class="p">,</span> <span class="s1">'encoder_output'</span><span class="p">)</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="s1">'encoder_input_lengths'</span><span class="p">)</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'encoder_max_input_length'</span><span class="p">],</span>
                       <span class="s1">'encoder_max_input_length'</span><span class="p">)</span>
            <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
                <span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_mask</span><span class="p">,</span> <span class="s1">'cross_attention_mask'</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
            <span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
            <span class="k">if</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
                <span class="n">hidden_states_input</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span>
                    <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">hidden_size</span><span class="p">)</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">hidden_states_input</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span>
                    <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">hidden_size</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">],</span> <span class="s1">'logits'</span><span class="p">)</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
                <span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'medusa_logits'</span><span class="p">],</span> <span class="s1">'medusa_logits'</span><span class="p">)</span>

            <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
                <span class="n">add_tensor</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="s1">'last_token_ids'</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">'hidden_states_output'</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="s1">'input_ids'</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">'hidden_states_input'</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">prompt_embedding_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="s1">'prompt_embedding_table'</span><span class="p">)</span>

            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
                <span class="n">tasks_generation</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">([</span>
                    <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">context_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">()],</span>
                               <span class="n">tasks</span><span class="p">[</span><span class="n">b</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
                               <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
                    <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
                <span class="p">])</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">tasks_generation</span> <span class="o">=</span> <span class="n">tasks</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="n">tasks_generation</span><span class="p">,</span> <span class="s1">'tasks'</span><span class="p">)</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="s1">'prompt_vocab_size'</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
            <span class="n">buffer</span> <span class="o">=</span> <span class="n">kv_cache_block_pointers</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
            <span class="n">shape</span> <span class="o">=</span> <span class="n">kv_cache_block_pointers</span><span class="o">.</span><span class="n">shape</span>
            <span class="n">shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="o">*</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">:]]</span>
            <span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">buffer</span><span class="p">,</span> <span class="s1">'kv_cache_block_pointers'</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
            <span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_kv_cache_block_pointers</span><span class="p">,</span>
                                  <span class="s1">'host_kv_cache_block_pointers'</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>

        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
            <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
                <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
                    <span class="n">kv_cache_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
                                      <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
                    <span class="c1"># For an empty tensor, TRT does not really use the tensor data, so any dtype is fine.</span>
                    <span class="n">kv_cache_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span>
                                                  <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
                                                  <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
                    <span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">kv_cache_buffer</span><span class="p">,</span>
                                          <span class="sa">f</span><span class="s1">'past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span>
                                          <span class="n">kv_cache_shape</span><span class="p">)</span>
                    <span class="n">present</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span>
                    <span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">present</span><span class="p">],</span> <span class="n">present</span><span class="p">)</span>

                    <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
                        <span class="n">cross_kv_cache_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span>
                                                <span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
                                                <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
                        <span class="c1"># For an empty tensor, TRT does not really use the tensor data, so any dtype is fine.</span>
                        <span class="n">cross_kv_cache_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span>
                                                            <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
                                                            <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
                        <span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">cross_kv_cache_buffer</span><span class="p">,</span>
                                              <span class="sa">f</span><span class="s1">'cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span>
                                              <span class="n">cross_kv_cache_shape</span><span class="p">)</span>
                        <span class="n">cross_present</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span>
                        <span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">cross_present</span><span class="p">],</span> <span class="n">cross_present</span><span class="p">)</span>
                <span class="k">else</span><span class="p">:</span>
                    <span class="n">key_value_cache</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
                    <span class="c1"># When the plugin is used, the past_key_value tensor does not need to be an empty tensor,</span>
                    <span class="c1"># because the plugin does not use this shape.</span>
                    <span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
                    <span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>

                    <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
                        <span class="n">cross_cache_buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
                            <span class="sa">f</span><span class="s1">'cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
                        <span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
                                   <span class="sa">f</span><span class="s1">'cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
                        <span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
                                   <span class="sa">f</span><span class="s1">'cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
            <span class="c1"># context request</span>
            <span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
                                                  <span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
            <span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span>
                                  <span class="s1">'sequence_length'</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">))</span>

            <span class="c1"># field 0: past_key_value_length; field 1: is_context (deprecated). Changed to [0], otherwise it affects the batch-padded input mode.</span>
            <span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span>
                                  <span class="s1">'host_past_key_value_lengths'</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">))</span>
            <span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span><span class="p">,</span>
                                  <span class="s1">'host_sink_token_length'</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">))</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">'host_request_types'</span><span class="p">)</span>
            <span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span><span class="p">,</span>
                                  <span class="s1">'host_max_attention_window_sizes'</span><span class="p">,</span>
                                  <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="p">))</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
                <span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span> <span class="s1">'host_context_lengths'</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">add_tensor</span><span class="p">(</span><span class="n">attention_mask</span><span class="p">,</span> <span class="s1">'attention_mask'</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_custom_all_reduce</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span><span class="p">,</span> <span class="s1">'all_reduce_workspace'</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span><span class="p">:</span>
|
||
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
|
||
<span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">+</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="p">):</span>
|
||
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
|
||
<span class="n">lora_ranks</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_ranks_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span>
|
||
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_ranks</span><span class="p">],</span> <span class="n">lora_ranks</span><span class="p">)</span>
|
||
<span class="n">lora_weights</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span>
|
||
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_weights</span><span class="p">],</span> <span class="n">lora_weights</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'cpu'</span><span class="p">),</span>
<span class="s1">'host_encoder_input_lengths'</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="c1"># Medusa mask and position offsets are fixed for the whole session.</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'medusa_packed_mask'</span><span class="p">],</span> <span class="s1">'medusa_packed_mask'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'medusa_position_offsets'</span><span class="p">],</span>
<span class="s1">'medusa_position_offsets'</span><span class="p">)</span>

<span class="k">return</span> <span class="n">tensors</span>

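<span class="c1"># Assembles the engine tensor bindings for a single generation step; the shapes</span>
<span class="c1"># depend on `step` because non-paged KV caches grow by one token per step.</span>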
<span class="k">def</span> <span class="nf">_get_next_step_shape_buffer</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="p">,</span>
|
||
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">position_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
|
||
<span class="n">host_kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
|
||
<span class="n">hidden_states_input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
|
||
<span class="n">tensors</span> <span class="o">=</span> <span class="p">{}</span> <span class="c1"># Dict[str, RuntimeTensor]</span>
|
||
|
||
<span class="k">def</span> <span class="nf">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="nf">add_tensor</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="n">name</span><span class="p">:</span> <span class="n">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">)})</span>
|
||
|
||
<span class="k">def</span> <span class="nf">add_tensor_with_shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
|
||
<span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">override_shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)})</span>

<span class="n">context_lengths_local</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="n">host_context_lengths_local</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_context_fmha_for_generation</span><span class="p">:</span>
<span class="n">context_lengths_local</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">host_context_lengths_local</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">context_lengths_local</span><span class="p">,</span> <span class="s1">'context_lengths'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cache_indirection</span><span class="p">,</span> <span class="s1">'cache_indirection'</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
<span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">hidden_size</span><span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="k">else</span> <span class="p">(</span>
<span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
<span class="n">hidden_states_input</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span><span class="o">*</span><span class="n">shape</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">],</span> <span class="s1">'logits'</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'medusa_logits'</span><span class="p">],</span> <span class="s1">'medusa_logits'</span><span class="p">)</span>

<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="s1">'last_token_ids'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">'hidden_states_output'</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="n">input_ids_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
<span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="k">else</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="p">,</span> <span class="s1">'input_ids'</span><span class="p">,</span>
<span class="n">input_ids_shape</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="s1">'input_ids'</span><span class="p">,</span>
<span class="n">input_ids_shape</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">'hidden_states_input'</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths_local</span><span class="p">,</span> <span class="s1">'host_context_lengths'</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="s1">'position_ids'</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="c1"># disable (or minimize) cross qkv computation at generation phase</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_qkv</span><span class="p">:</span>
<span class="c1"># disable</span>
<span class="n">encoder_output_shape</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cross_qkv_reuse</span><span class="p">,</span> <span class="s1">'cross_qkv_reuse'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># minimize</span>
<span class="c1"># hack: shrink the input so that the qkv gemm degenerates into a gemv, which is cheap and negligible</span>
<span class="n">encoder_output_shape</span> <span class="o">=</span> <span class="p">[</span>
<span class="mi">1</span><span class="p">,</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="p">]</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="k">else</span> <span class="p">[</span>
<span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># The OOTB path has no kv cache for now, so encoder_output is</span>
<span class="c1"># a must-have input; use it as-is.</span>
<span class="n">encoder_output_shape</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span>

<span class="c1"># In the generation phase the cross kv cache was already filled during the context phase, so set this flag to False.</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
<span class="s1">'cross_kv_cache_gen'</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">encoder_output</span><span class="p">,</span> <span class="s1">'encoder_output'</span><span class="p">,</span>
<span class="n">encoder_output_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="s1">'encoder_input_lengths'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'encoder_max_input_length'</span><span class="p">],</span>
<span class="s1">'encoder_max_input_length'</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_mask</span><span class="p">,</span> <span class="s1">'cross_attention_mask'</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">kv_cache_block_pointers</span><span class="o">.</span><span class="n">shape</span>
<span class="n">shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="o">*</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">:]]</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">kv_cache_block_pointers</span><span class="p">,</span>
<span class="s1">'kv_cache_block_pointers'</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_kv_cache_block_pointers</span><span class="p">,</span>
<span class="s1">'host_kv_cache_block_pointers'</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>

<span class="k">if</span> <span class="n">prompt_embedding_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="s1">'prompt_embedding_table'</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">gen_tasks</span> <span class="o">=</span> <span class="n">tasks</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">gen_tasks</span> <span class="o">=</span> <span class="n">tasks</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">gen_tasks</span><span class="p">,</span> <span class="s1">'tasks'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="s1">'prompt_vocab_size'</span><span class="p">)</span>

<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">next_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span>
<span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'1_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">'past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span> <span class="n">next_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">'past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span> <span class="n">next_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'1_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">key_value_cache</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_cache_buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">'cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">'cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">'cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="c1"># generation requests</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_context_fmha_for_generation</span><span class="p">:</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">host_past_key_value_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># The previous [past_kv_length, is_context] pair has been deprecated; only past_kv_length should be given here.</span>
<span class="c1"># Note: max_context_length is used here to align to the batch maximum -- though isn't that already handled by the attention plugin's max_element()?</span>
<span class="n">host_past_key_value_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
<span class="p">[</span><span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_past_key_value_lengths</span><span class="p">,</span>
<span class="s1">'host_past_key_value_lengths'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">'host_request_types'</span><span class="p">)</span>
<span class="c1"># Sequence lengths are not actually used in the context phase.</span>
<span class="n">sequence_length</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_context_fmha_for_generation</span><span class="p">:</span>
<span class="n">sequence_length</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="n">sequence_length</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="s1">'sequence_length'</span><span class="p">,</span>
<span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="p">))</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span><span class="p">,</span>
<span class="s1">'host_sink_token_length'</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">))</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span><span class="p">,</span>
<span class="s1">'host_max_attention_window_sizes'</span><span class="p">,</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths_local</span><span class="p">,</span> <span class="s1">'host_context_lengths'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">attention_mask</span><span class="p">,</span> <span class="s1">'attention_mask'</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_custom_all_reduce</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span><span class="p">,</span> <span class="s1">'all_reduce_workspace'</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span><span class="p">:</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
<span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">+</span>
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="p">):</span>
<span class="n">lora_ranks</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_ranks_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_ranks</span><span class="p">],</span> <span class="n">lora_ranks</span><span class="p">)</span>
<span class="n">lora_weights</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_weights</span><span class="p">],</span> <span class="n">lora_weights</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'cpu'</span><span class="p">),</span>
<span class="s1">'host_encoder_input_lengths'</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="c1"># Medusa mask and position offsets are fixed for the whole session.</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'medusa_packed_mask'</span><span class="p">],</span> <span class="s1">'medusa_packed_mask'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'medusa_position_offsets'</span><span class="p">],</span>
<span class="s1">'medusa_position_offsets'</span><span class="p">)</span>

<span class="k">return</span> <span class="n">tensors</span>

<span class="k">def</span> <span class="nf">_prepare_context_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>

<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="c1"># For Medusa, last_token_ids should contain the actual indices</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span> <span class="c1"># subtract 1 from context_lengths to get indices</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">if</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'max_context_length'</span><span class="p">)</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">([</span>
<span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
<span class="p">])</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">expand</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'last_token_ids'</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">}</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'input_ids'</span><span class="p">)</span>
<span class="n">pad_id</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'pad_id'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">_prepare_attention_mask</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">pad_id</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">attention_mask</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">position_ids</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">attention_mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>

<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">'attention_mask'</span><span class="p">:</span> <span class="n">attention_mask</span><span class="p">,</span>
<span class="s1">'last_token_ids'</span><span class="p">:</span> <span class="n">last_token_ids</span>
<span class="p">}</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
<span class="n">ret</span><span class="p">[</span><span class="s1">'position_ids'</span><span class="p">]</span> <span class="o">=</span> <span class="n">position_ids</span>

<span class="k">return</span> <span class="n">ret</span>

<span class="k">def</span> <span class="nf">_prepare_generation_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>

<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span>

<span class="k">if</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">step</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'step'</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">context_lengths</span> <span class="o">+</span> <span class="n">step</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="c1"># For Medusa, last_token_ids should be [bs * seq] and should contain the actual indices (starts from 1)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">expand</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="c1"># For Medusa, last_token_ids should be [bs, seq] and should contain the actual indices (starts from 0)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">expand</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>

<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'last_token_ids'</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">}</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'attention_mask'</span><span class="p">)</span>
<span class="n">num_beams</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'num_beams'</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">attention_mask</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">,</span> <span class="mi">1</span><span class="p">))),</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">attention_mask</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">position_ids</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">attention_mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
|
||
|
||
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="s1">'last_token_ids'</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">,</span>
|
||
<span class="s1">'attention_mask'</span><span class="p">:</span> <span class="n">attention_mask</span><span class="p">,</span>
|
||
<span class="p">}</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
|
||
<span class="n">ret</span><span class="p">[</span><span class="s1">'position_ids'</span><span class="p">]</span> <span class="o">=</span> <span class="n">position_ids</span>
|
||
|
||
<span class="k">return</span> <span class="n">ret</span>

<div class="viewcode-block" id="GenerationSession.pp_communicate_new_tokens">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.pp_communicate_new_tokens">[docs]</a>
    <span class="k">def</span> <span class="nf">pp_communicate_new_tokens</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">should_stop</span><span class="p">,</span> <span class="n">cache_indir</span><span class="p">,</span>
                                  <span class="n">sequence_length</span><span class="p">):</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
            <span class="k">for</span> <span class="n">pg</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">:</span>
                <span class="k">if</span> <span class="n">pg</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="p">:</span>
                    <span class="k">continue</span>
                <span class="n">should_stop</span> <span class="o">=</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">should_stop</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">cache_indir</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">should_stop</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">should_stop</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">cache_indir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
        <span class="k">return</span> <span class="n">should_stop</span></div>


<div class="viewcode-block" id="GenerationSession.pp_communicate_final_output_ids">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.pp_communicate_final_output_ids">[docs]</a>
    <span class="k">def</span> <span class="nf">pp_communicate_final_output_ids</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">final_output_ids</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
                                        <span class="n">beam_width</span><span class="p">):</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
        <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
            <span class="n">final_output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
                <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">),</span>
                <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
                <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
        <span class="k">return</span> <span class="n">final_output_ids</span></div>


<div class="viewcode-block" id="GenerationSession.finalize_decoder">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.finalize_decoder">[docs]</a>
    <span class="k">def</span> <span class="nf">finalize_decoder</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                         <span class="n">context_lengths</span><span class="p">,</span>
                         <span class="n">batch_size</span><span class="p">,</span>
                         <span class="n">beam_width</span><span class="p">,</span>
                         <span class="n">scfg</span><span class="p">,</span>
                         <span class="n">in_progress</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
        <span class="n">final_output_ids</span> <span class="o">=</span> <span class="kc">None</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
            <span class="c1"># output shape of self.gather_tree: [batch_size, beam_width, output_len]</span>
            <span class="n">beam_hyps_args</span> <span class="o">=</span> <span class="p">[</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_tgt</span><span class="p">,</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_sequence_lengths_tgt</span><span class="p">,</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores</span><span class="p">,</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span><span class="p">,</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span>
            <span class="p">]</span>
            <span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span> <span class="ow">and</span> <span class="n">in_progress</span><span class="p">:</span>
                <span class="c1"># self.gather_tree modifies these args in place.</span>
                <span class="c1"># In streaming mode that would corrupt decoding in the following steps, so copy them.</span>
                <span class="n">beam_hyps_args</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">beam_hyps_args</span><span class="p">)</span>

            <span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_tree</span><span class="p">(</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span><span class="p">,</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">,</span>
                <span class="o">*</span><span class="n">beam_hyps_args</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">length_penalty</span><span class="p">,</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">early_stopping</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">)</span>

        <span class="c1"># Communicate across ranks under pipeline parallelism.</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
            <span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pp_communicate_final_output_ids</span><span class="p">(</span>
                <span class="n">final_output_ids</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">final_output_ids</span></div>


<div class="viewcode-block" id="GenerationSession.find_best_medusa_path">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.find_best_medusa_path">[docs]</a>
    <span class="k">def</span> <span class="nf">find_best_medusa_path</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                              <span class="n">batch_size</span><span class="p">,</span>
                              <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
                              <span class="n">next_logits</span><span class="p">,</span>
                              <span class="n">temp</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
        <span class="k">assert</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span>
        <span class="n">best_path</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span>
        <span class="n">best_path_len</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span>
        <span class="n">next_tokens</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span>
        <span class="n">zero_pad</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
                               <span class="n">dtype</span><span class="o">=</span><span class="n">input_ids</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
                               <span class="n">device</span><span class="o">=</span><span class="n">input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
        <span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">zero_pad</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">temp</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">new_tokens_raw</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span>
                <span class="n">next_logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span>
            <span class="p">)</span>  <span class="c1"># TODO: can be done by treating [bs, nT, vocab] as [bs*nT, vocab] and using decoderOp?</span>
            <span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">new_tokens_raw</span><span class="p">,</span> <span class="n">zero_pad</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
            <span class="n">input_paths</span> <span class="o">=</span> <span class="p">[</span>
                <span class="n">input_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span><span class="p">]</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
            <span class="p">]</span>
            <span class="n">new_paths</span> <span class="o">=</span> <span class="p">[</span>
                <span class="n">new_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span><span class="p">]</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
            <span class="p">]</span>
            <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
                <span class="n">equality</span> <span class="o">=</span> <span class="n">input_paths</span><span class="p">[</span><span class="n">b</span><span class="p">][:,</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">==</span> <span class="n">new_paths</span><span class="p">[</span><span class="n">b</span><span class="p">][:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
                <span class="n">paths_correct_len</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumprod</span><span class="p">(</span><span class="n">equality</span><span class="o">.</span><span class="n">int</span><span class="p">(),</span>
                                                  <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
                <span class="n">best_path_len</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="n">paths_correct_len</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
                <span class="k">if</span> <span class="n">best_path_len</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
                    <span class="n">best_path</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">paths_correct_len</span><span class="p">)</span>
                <span class="n">next_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_paths</span><span class="p">[</span><span class="n">b</span><span class="p">][</span>
                    <span class="n">best_path</span><span class="p">[</span><span class="n">b</span><span class="p">]][:</span><span class="n">best_path_len</span><span class="p">[</span><span class="n">b</span><span class="p">]]</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>

        <span class="k">return</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_len</span><span class="p">,</span> <span class="n">next_tokens</span></div>


<div class="viewcode-block" id="GenerationSession.filter_medusa_logits">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.filter_medusa_logits">[docs]</a>
    <span class="k">def</span> <span class="nf">filter_medusa_logits</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">,</span>
                             <span class="n">medusa_logits</span><span class="p">):</span>
<span class="w">        </span><span class="sd">"""</span>
<span class="sd">        medusa_logits is of shape [nMH, bs, nMT+1, vocab]</span>
<span class="sd">        (nMH = num_medusa_heads, nMT = num_medusa_tokens).</span>

<span class="sd">        Returns [nMH, bs, vocab]</span>
<span class="sd">        """</span>
        <span class="n">filtered_logits</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
            <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
            <span class="n">dtype</span><span class="o">=</span><span class="n">medusa_logits</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
            <span class="n">device</span><span class="o">=</span><span class="n">medusa_logits</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
        <span class="n">medusa_logits</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
                                           <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
            <span class="n">idx</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span><span class="p">[</span><span class="n">best_path</span><span class="p">[</span><span class="n">b</span><span class="p">],</span> <span class="n">best_path_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
            <span class="n">filtered_logits</span><span class="p">[:,</span> <span class="n">b</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="p">[:,</span> <span class="n">b</span><span class="p">,</span> <span class="n">idx</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span>
        <span class="k">return</span> <span class="n">filtered_logits</span></div>


<div class="viewcode-block" id="GenerationSession.get_next_medusa_tokens">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.get_next_medusa_tokens">[docs]</a>
    <span class="k">def</span> <span class="nf">get_next_medusa_tokens</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">next_medusa_logits</span><span class="p">):</span>
        <span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
                        <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
                        <span class="n">device</span><span class="o">=</span><span class="n">next_medusa_logits</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
        <span class="p">]</span>  <span class="c1"># dummy token for now, TODO: update tree_ids and remove this</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">):</span>
            <span class="n">medusa_token</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">next_medusa_logits</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:],</span>
                                      <span class="bp">self</span><span class="o">.</span><span class="n">medusa_topks</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
                                      <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">indices</span>
            <span class="n">next_medusa_tokens</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">medusa_token</span><span class="p">)</span>
        <span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">next_medusa_tokens</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">next_medusa_tokens</span></div>


<div class="viewcode-block" id="GenerationSession.update_kv_cache_draft_token_location">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.update_kv_cache_draft_token_location">[docs]</a>
    <span class="k">def</span> <span class="nf">update_kv_cache_draft_token_location</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span>
                                             <span class="n">best_path_len</span><span class="p">):</span>
        <span class="n">best_path_len_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">best_path_len</span><span class="p">,</span>
                                            <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int</span><span class="p">,</span>
                                            <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
        <span class="n">accepted_draft_token_counts</span> <span class="o">=</span> <span class="n">best_path_len_tensor</span> <span class="o">-</span> <span class="mi">1</span>
        <span class="n">accepted_draft_token_offsets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
                                                   <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
                                                   <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
        <span class="n">accepted_draft_token_offsets</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span>
            <span class="n">accepted_draft_token_counts</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="n">accepted_draft_token_offsets_cpu</span> <span class="o">=</span> <span class="n">accepted_draft_token_offsets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
            <span class="s1">'cpu'</span><span class="p">)</span>
        <span class="n">packed_accepted_draft_tokens_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
            <span class="n">accepted_draft_token_offsets_cpu</span><span class="p">[</span><span class="n">batch_size</span><span class="p">],</span>
            <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
            <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">seq_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
            <span class="n">seq_start</span> <span class="o">=</span> <span class="n">accepted_draft_token_offsets_cpu</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">]</span>
            <span class="n">seq_end</span> <span class="o">=</span> <span class="n">accepted_draft_token_offsets_cpu</span><span class="p">[</span><span class="n">seq_idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
            <span class="n">seq_accepted_draft_count</span> <span class="o">=</span> <span class="n">seq_end</span> <span class="o">-</span> <span class="n">seq_start</span>
            <span class="n">best_path_idx</span> <span class="o">=</span> <span class="n">best_path</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">]</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
                <span class="n">best_path</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="n">best_path</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">]</span>
            <span class="n">seq_accepted_token_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span><span class="p">[</span>
                <span class="n">best_path_idx</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span><span class="mi">1</span> <span class="o">+</span> <span class="n">seq_accepted_draft_count</span><span class="p">]</span>
            <span class="n">packed_accepted_draft_tokens_indices</span><span class="p">[</span>
                <span class="n">seq_start</span><span class="p">:</span><span class="n">seq_end</span><span class="p">]</span> <span class="o">=</span> <span class="n">seq_accepted_token_indices</span> <span class="o">-</span> <span class="mi">1</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">accepted_draft_token_offsets</span><span class="p">,</span>
                                     <span class="n">packed_accepted_draft_tokens_indices</span><span class="p">,</span>
                                     <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span>
                                     <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span></div>


<div class="viewcode-block" id="GenerationSession.update_output_ids_by_offset">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.update_output_ids_by_offset">[docs]</a>
    <span class="k">def</span> <span class="nf">update_output_ids_by_offset</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_generated_ids</span><span class="p">,</span> <span class="n">offsets</span><span class="p">):</span>
        <span class="c1"># output_ids [batch_size, padded_input_length]</span>
        <span class="c1"># new_generated_ids [batch_size, padded_accepted_length]</span>
        <span class="c1"># offsets [batch_size]</span>
        <span class="c1"># FIXME: use a fused kernel to update the padded output ids.</span>
        <span class="n">batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="n">offsets</span><span class="p">[</span><span class="n">b</span><span class="p">]:(</span>
                <span class="n">offsets</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
            <span class="p">)]</span> <span class="o">=</span> <span class="n">new_generated_ids</span><span class="p">[</span><span class="n">b</span><span class="p">][:</span><span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]]</span></div>


<div class="viewcode-block" id="GenerationSession.next_medusa_input_ids">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.next_medusa_input_ids">[docs]</a>
    <span class="k">def</span> <span class="nf">next_medusa_input_ids</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="c1"># self.new_tokens [batch_size, padded_accepted_length]</span>
        <span class="c1"># self.accept_lengths [batch_size]</span>
        <span class="c1"># self.medusa_new_tokens [batch_size, num_medusa_tokens]</span>
        <span class="c1"># FIXME: use a fused kernel to generate the new medusa input ids.</span>
        <span class="n">batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">[</span>
                <span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="p">:]</span></div>


    <span class="c1"># OPTIMIZE: need to optimize this early-stop workflow.</span>
<div class="viewcode-block" id="GenerationSession.early_stop_criteria">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.early_stop_criteria">[docs]</a>
    <span class="k">def</span> <span class="nf">early_stop_criteria</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">should_stop</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_step</span><span class="p">[</span><span class="n">b</span><span class="p">]:</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
                <span class="k">continue</span>
            <span class="c1"># output-sequence-length criterion.</span>
            <span class="n">prev_total_output_length</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
            <span class="c1"># end-id criterion.</span>
            <span class="n">should_stop_with_end_id</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">any</span><span class="p">(</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">[</span><span class="n">b</span><span class="p">])</span>
            <span class="n">end_id_pos</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]]</span> <span class="o">==</span>
                          <span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">[</span><span class="n">b</span><span class="p">])</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_step</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_step</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="ow">or</span> <span class="p">(</span>
                <span class="n">prev_total_output_length</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">>=</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">)</span> <span class="ow">or</span> <span class="n">should_stop_with_end_id</span>
            <span class="c1"># update accept lengths for the current step.</span>
            <span class="k">if</span> <span class="p">(</span><span class="n">prev_total_output_length</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">>=</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="n">prev_total_output_length</span><span class="p">,</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">])</span>
            <span class="k">if</span> <span class="n">should_stop_with_end_id</span><span class="p">:</span>
                <span class="c1"># truncate at the position of the first end_id.</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">end_id_pos</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
                                             <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">])</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">total_accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>

        <span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">or</span> <span class="p">(</span><span class="n">step</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span>
                                            <span class="mi">1</span><span class="p">)</span> <span class="ow">or</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span>
                                                <span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_step</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">should_stop</span></div>


<div class="viewcode-block" id="GenerationSession.process_logits_for_medusa_mode">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.process_logits_for_medusa_mode">[docs]</a>
    <span class="k">def</span> <span class="nf">process_logits_for_medusa_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span>
                                       <span class="n">logits</span><span class="p">,</span> <span class="n">context_has_medusa_tokens</span><span class="p">,</span>
                                       <span class="n">next_step_buffer</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">):</span>
        <span class="n">medusa_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'medusa_logits'</span><span class="p">]</span>
        <span class="n">best_path</span> <span class="o">=</span> <span class="kc">None</span>
        <span class="n">best_path_lengths</span> <span class="o">=</span> <span class="kc">None</span>
        <span class="n">should_stop</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="kc">False</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
            <span class="c1"># The logits buffer is of shape [bs, medusa_tokens + 1, vocab], but</span>
            <span class="c1"># the context phase fills only a contiguous [bs, 1, vocab] slice.</span>
            <span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)[:</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">logits</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
                <span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
            <span class="n">next_main_token_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
            <span class="n">next_main_token</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">next_main_token_logits</span><span class="p">,</span>
                                           <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span>
                                           <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">next_main_token</span>
            <span class="c1"># NOTE: stop criteria.</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_step</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span>
                                               <span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">)</span>
            <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">equal</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">):</span>
                <span class="c1"># stop if every sequence's context-phase output is EOS</span>
                <span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
            <span class="c1"># NOTE: only one token's medusa logits will be written in.</span>
            <span class="n">medusa_logits</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
                                               <span class="o">-</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span>
            <span class="n">next_medusa_logits</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
                <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
            <span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_next_medusa_tokens</span><span class="p">(</span>
                <span class="n">batch_size</span><span class="p">,</span> <span class="n">next_medusa_logits</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span> <span class="o">=</span> <span class="n">next_medusa_tokens</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_tree_ids</span><span class="p">[</span>
|
||
<span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span><span class="p">:]]</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
||
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">total_accept_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">next_token_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
|
||
|
||
<span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">,</span> <span class="n">next_main_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">find_best_medusa_path</span><span class="p">(</span>
|
||
<span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">next_token_logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="o">-</span><span class="mi">1</span><span class="p">))</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">best_path_lengths</span><span class="p">,</span>
|
||
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">to_padded_tensor</span><span class="p">(</span>
|
||
<span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">nested_tensor</span><span class="p">(</span><span class="n">next_main_tokens</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="c1">#FIXME end id padding.</span>
<span class="n">next_medusa_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_medusa_logits</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">,</span> <span class="n">medusa_logits</span><span class="p">)</span>
<span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_next_medusa_tokens</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">next_medusa_logits</span><span class="p">)</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">early_stop_criteria</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span>
<span class="n">should_stop</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span> <span class="o">=</span> <span class="n">next_medusa_tokens</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_tree_ids</span><span class="p">[</span>
<span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span><span class="p">:]]</span>
<span class="c1"># NOTE: self.accept_lengths are the lengths of accepted tokens in the current step</span>
<span class="c1"># NOTE: self.sequence_length_buffer = num_past_kv_cache (accepted) + num_medusa_tokens + 1</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_output_ids_by_offset</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># Note:self.sequence_length_buffer = num_past_kv_cache (accepted) + num_medusa_tokens</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_output_ids_by_offset</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">next_medusa_input_ids</span><span class="p">()</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">best_path</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">best_path_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_kv_cache_draft_token_location</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span> <span class="mi">1</span>
<span class="c1"># NOTE: set the accepted tokens for the last step.</span>
<span class="k">if</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="c1"># remove num_medusa_tokens for next generation.</span>
<span class="c1"># Runtime: denotes kv cache length start positions.</span>
<span class="c1"># Output: denotes the total sequence length (input ids + output ids)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span>
<span class="n">next_step_buffer</span><span class="p">[</span><span class="s1">'host_past_key_value_lengths'</span><span class="p">]</span><span class="o">.</span><span class="n">to_torch</span><span class="p">()</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">)</span>
<span class="k">return</span> <span class="n">should_stop</span></div>
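<span class="c1"># NOTE (illustrative sketch, not part of the runtime logic): a hypothetical</span>
<span class="c1"># trace of the bookkeeping above, assuming num_medusa_tokens = 4 and a</span>
<span class="c1"># 10-token context:</span>
<span class="c1">#   step 0: tokens are written at offset 10 (= context length); afterwards</span>
<span class="c1">#           sequence_length_buffer = 10 + 4 + 1 = 15 (past kv + drafts + 1).</span>
<span class="c1">#   step 1: suppose accept_lengths = 3 and generation stops here; the new</span>
<span class="c1">#           tokens are written at offset 15 - 4 = 11.</span>
<span class="c1">#   stop:   sequence_length_buffer = 15 + 3 - 4 = 14, the true sequence</span>
<span class="c1">#           length (10 context + 1 + 3 generated tokens).</span>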
<div class="viewcode-block" id="GenerationSession.handle_per_step">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.handle_per_step">[docs]</a>
<span class="k">def</span> <span class="nf">handle_per_step</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span> <span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">host_kv_cache_block_pointers</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">next_step_tensors</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">RuntimeTensor</span><span class="p">],</span> <span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="p">,</span> <span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_0</span>
<span class="n">this_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">this_tgt_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">next_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_1</span>
<span class="n">this_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">this_tgt_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">next_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
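<span class="c1"># NOTE (illustrative): the two execution contexts and the two</span>
<span class="c1"># cache_indirection buffers ping-pong across steps, so the beam indices</span>
<span class="c1"># written at step N (tgt) are read back at step N + 1 (src), e.g.:</span>
<span class="c1">#   step 1 (odd):  context_0, src = cache_indirections[1], tgt = cache_indirections[0]</span>
<span class="c1">#   step 2 (even): context_1, src = cache_indirections[0], tgt = cache_indirections[1]</span>
<span class="c1"># (step 0 instead runs on the dedicated context self.runtime.ctx_context).</span>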
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">model_inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_context_inputs</span><span class="p">(</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="o">=</span><span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_context_length</span><span class="p">,</span>
<span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">pad_id</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span><span class="p">,</span>
<span class="n">eos_id</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'position_ids'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'last_token_ids'</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'attention_mask'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">host_kv_cache_block_pointers</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">get_block_pointers</span><span class="p">(</span>
<span class="mi">1</span><span class="p">)</span>
<span class="n">kv_cache_block_pointers</span> <span class="o">=</span> <span class="n">host_kv_cache_block_pointers</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="s1">'cuda'</span><span class="p">)</span>
<span class="n">ctx_tensors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_context_shape_buffer</span><span class="p">(</span>
<span class="n">input_ids</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">position_ids</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">,</span>
<span class="n">this_src_cache_indirection</span><span class="p">,</span> <span class="n">kv_cache_block_pointers</span><span class="p">,</span>
<span class="n">host_kv_cache_block_pointers</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">prompt_vocab_size</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">)</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">ctx_context</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_set_tensors</span><span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">ctx_tensors</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span> <span class="o">=</span> <span class="p">{</span>
<span class="n">name</span><span class="p">:</span> <span class="n">tensor</span><span class="o">.</span><span class="n">to_torch</span><span class="p">()</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="n">ctx_tensors</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="p">}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span><span class="p">:</span>
<span class="c1"># context mode, clean cuda graph instances</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">)]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_check_tensors</span><span class="p">(</span><span class="n">context</span><span class="p">)</span>
<span class="c1"># dynamic_decoder currently uses torch's current stream, so TRT enqueue must use the same stream here</span>
<span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">cuda_stream</span>
<span class="n">instance_idx</span> <span class="o">=</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span>
<span class="n">instance_idx</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># launch cuda graph</span>
<span class="n">CUASSERT</span><span class="p">(</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphLaunch</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">],</span> <span class="n">stream</span><span class="p">))</span>
<span class="n">ok</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">ok</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_run</span><span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">ok</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Executing TRT engine failed step=</span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">!"</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">context_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span>
<span class="n">context_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="c1"># gather last token of context</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="c1"># reshape self.buffer['logits'] from [bs, max_context_length, vocab]</span>
<span class="c1"># to [1, bs * max_context_length, vocab]</span>
<span class="c1"># Note that the data are put in the buffer without padding although</span>
<span class="c1"># the allocated buffer has padding.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">index_select</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">],</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">],</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">index</span><span class="o">=</span><span class="n">last_token_ids</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">))</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
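<span class="c1"># NOTE (illustrative): with remove_input_padding, the context tokens of all</span>
<span class="c1"># sequences are packed into a single row and last_token_ids holds the</span>
<span class="c1"># inclusive prefix sums of the context lengths; e.g. hypothetical lengths</span>
<span class="c1"># [3, 5] give last_token_ids = [3, 8], so index_select with</span>
<span class="c1"># last_token_ids - 1 picks packed positions 2 and 7, i.e. the last context</span>
<span class="c1"># token of each sequence.</span>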
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">beam_width</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span>
<span class="c1"># these tiled tensors are returned by handle_per_step(), so they can be relayed to the next generation calls</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">attention_mask</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="n">context_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="n">host_context_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">)</span>
<span class="k">if</span> <span class="n">encoder_input_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
<span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="k">if</span> <span class="n">tasks</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">tasks</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">tasks</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="c1"># Tiling is done before computing the context logits</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="c1"># Note: this tiles both self attn cache and cross attn cache!</span>
<span class="c1"># both names contain "present_key_value"</span>
<span class="k">if</span> <span class="s2">"present_key_value"</span> <span class="ow">in</span> <span class="n">key</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">],</span> <span class="n">beam_width</span><span class="p">)</span>
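<span class="c1"># NOTE (illustrative): _tile_beam_width repeats each batch entry beam_width</span>
<span class="c1"># times along the batch axis; e.g. a hypothetical context_lengths of [3, 5]</span>
<span class="c1"># with beam_width = 2 becomes [3, 3, 5, 5], so every beam of a sequence</span>
<span class="c1"># starts from the same context state.</span>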
<span class="n">generation_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span><span class="p">:</span>
<span class="n">generation_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="c1"># Initialize sequence_lengths (no paddings) for the generation phase.</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="c1"># NOTE: handle next step.</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">step</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># Set shape and address for the next step</span>
<span class="n">model_inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_generation_inputs</span><span class="p">(</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">,</span>
<span class="n">step</span><span class="o">=</span><span class="n">step</span><span class="p">,</span>
<span class="n">num_beams</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'position_ids'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'last_token_ids'</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'attention_mask'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="c1"># Prepare for the next step, and always allocate 1 token slot.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="c1"># Iterate to the next step in KV cache manager.</span>
<span class="c1"># Increase number of tokens for all unfinished sequences.</span>
<span class="c1"># And allocate new blocks if needed.</span>
<span class="c1"># We set this to False for all sequences, since for now we only use the length criterion to stop</span>
<span class="c1"># OPTIMIZE: find a better way of adding multiple tokens for paged kv cache.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
<span class="c1"># Allocate kv cache token slots for next step.</span>
<span class="c1"># Make sure there are always > (num_medusa_tokens + 1) free token slots.</span>
<span class="c1"># Allocate (num_medusa_tokens + 1) * 2 for safety as we don't know the current step or next step's accepted lengths.</span>
<span class="n">add_token_count</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_tokens</span> <span class="o">+</span>
<span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="k">assert</span> <span class="n">add_token_count</span> <span class="o">></span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">new_tokens</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">add_token_count</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">False</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">False</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
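<span class="c1"># NOTE (illustrative): with a hypothetical num_medusa_tokens = 4, step 0</span>
<span class="c1"># reserves (4 + 1) * 2 = 10 token slots, since neither this step's nor the</span>
<span class="c1"># next step's accepted lengths are known yet; later steps reserve</span>
<span class="c1"># max(accept_lengths) slots, e.g. 3 if at most 3 tokens were accepted.</span>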
<span class="n">host_kv_cache_block_pointers</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">get_block_pointers</span><span class="p">(</span>
<span class="n">beam_width</span><span class="p">)</span>
<span class="n">kv_cache_block_pointers</span> <span class="o">=</span> <span class="n">host_kv_cache_block_pointers</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="s1">'cuda'</span><span class="p">)</span>
<span class="n">next_context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_1</span> <span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_0</span>
<span class="n">next_step_tensors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_next_step_shape_buffer</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">position_ids</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">,</span>
<span class="n">next_src_cache_indirection</span><span class="p">,</span> <span class="n">kv_cache_block_pointers</span><span class="p">,</span>
<span class="n">host_kv_cache_block_pointers</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">prompt_vocab_size</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">)</span>
<span class="c1"># some tensors created inside _get_next_step_shape_buffer are not owned by any object,</span>
<span class="c1"># so we need to prolong the lifetime of the tensors inside the next_step_tensors array;</span>
<span class="c1"># otherwise they may be released before the next step is actually enqueued.</span>
<span class="c1"># one way to prolong it is to return the list, and destroy it in the next step by assigning new values</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_set_tensors</span><span class="p">(</span><span class="n">next_context</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_capture_cuda_graph_and_instantiate</span><span class="p">(</span>
<span class="n">next_context</span><span class="p">,</span> <span class="n">stream</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span>
<span class="k">if</span> <span class="n">logits</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">process_logits_for_medusa_mode</span><span class="p">(</span>
<span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">logits_processor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">,</span>
<span class="n">in_progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># keep the shape the same as the huggingface stopping_criteria</span>
<span class="n">final_output_ids_</span> <span class="o">=</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">logits_processor</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">final_output_ids_</span><span class="p">,</span>
<span class="n">logits</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">logits</span>
<span class="c1"># [batch_size x beam_width, vocab_size_padded] -> [batch_size, beam_width, vocab_size_padded]</span>
<span class="n">next_token_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
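<span class="c1"># NOTE (illustrative): e.g. a hypothetical batch_size = 2, beam_width = 3,</span>
<span class="c1"># vocab_size_padded = 32000 reshapes logits from [6, 32000] to</span>
<span class="c1"># [2, 3, 32000], the layout dynamic_decoder.forward expects.</span>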
<span class="n">decode_step</span> <span class="o">=</span> <span class="n">step</span> <span class="o">+</span> <span class="n">max_context_length</span>
<span class="n">stop_words_list_ptrs</span><span class="p">,</span> <span class="n">stop_words_lens</span><span class="p">,</span> <span class="n">max_stop_words_len</span> <span class="o">=</span> <span class="n">stop_words_data</span>
<span class="n">bad_words_list_ptrs</span><span class="p">,</span> <span class="n">bad_words_lens</span><span class="p">,</span> <span class="n">max_bad_words_len</span> <span class="o">=</span> <span class="n">bad_words_data</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span>
<span class="n">next_token_logits</span><span class="p">,</span> <span class="n">decode_step</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">,</span>
<span class="n">ite</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding_bias_opt</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">sequence_limit_lengths</span><span class="p">,</span>
<span class="n">stop_words_list_ptrs</span><span class="p">,</span> <span class="n">stop_words_lens</span><span class="p">,</span>
<span class="n">max_stop_words_len</span><span class="p">,</span> <span class="n">bad_words_list_ptrs</span><span class="p">,</span> <span class="n">bad_words_lens</span><span class="p">,</span>
<span class="n">max_bad_words_len</span><span class="p">,</span> <span class="n">no_repeat_ngram_size</span><span class="p">,</span>
<span class="n">this_src_cache_indirection</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_probs_tiled</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span><span class="p">,</span>
<span class="n">this_tgt_cache_indirection</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_tgt</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_sequence_lengths_tgt</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">)</span>
<span class="k">if</span> <span class="n">stopping_criteria</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">,</span>
<span class="n">in_progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># keep the shape the same as the huggingface stopping_criteria</span>
<span class="n">final_output_ids_</span> <span class="o">=</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
<span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">stopping_criteria</span><span class="p">(</span>
<span class="n">step</span><span class="p">,</span> <span class="n">final_output_ids_</span><span class="p">,</span> <span class="n">logits</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pp_communicate_new_tokens</span><span class="p">(</span>
<span class="n">should_stop</span><span class="p">,</span> <span class="n">this_tgt_cache_indirection</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">if</span> <span class="p">(</span><span class="n">step</span> <span class="o">>=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">()):</span>
<span class="c1"># Free all blocks in all sequences.</span>
<span class="c1"># With in-flight batching and a while loop, we'll free some sequences when they are done</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">True</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dump_debug_buffers</span><span class="p">(</span><span class="n">step</span><span class="p">)</span>
<span class="k">if</span> <span class="n">next_step_tensors</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span> <span class="o">=</span> <span class="p">{</span>
<span class="n">name</span><span class="p">:</span> <span class="n">tensor</span><span class="o">.</span><span class="n">to_torch</span><span class="p">()</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="n">next_step_tensors</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="p">}</span>
<span class="k">return</span> <span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">context_logits</span><span class="p">,</span> <span class="n">generation_logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span></div>
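<span class="c1"># NOTE (illustrative sketch, assuming a hypothetical driver loop): the</span>
<span class="c1"># returned tensors are threaded into the next call, which also keeps the</span>
<span class="c1"># buffers referenced by next_step_tensors alive until they are enqueued:</span>
<span class="c1">#   for step in range(self.max_new_tokens):</span>
<span class="c1">#       (should_stop, next_step_tensors, tasks, context_lengths,</span>
<span class="c1">#        host_context_lengths, attention_mask, context_logits,</span>
<span class="c1">#        generation_logits, encoder_input_lengths) = self.handle_per_step(...)</span>
<span class="c1">#       if should_stop is not None and should_stop.item():</span>
<span class="c1">#           break</span>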
<div class="viewcode-block" id="GenerationSession.dump_debug_buffers">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.dump_debug_buffers">[docs]</a>
<span class="k">def</span> <span class="nf">dump_debug_buffers</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># restricted written tensors according to filter</span>
<span class="n">debug_tensor_names</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">debug_tensor_names</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">all</span><span class="p">([</span><span class="n">kk</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">k</span> <span class="k">for</span> <span class="n">kk</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span><span class="p">]):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="n">k</span><span class="p">)</span>
<span class="n">debug_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">"tllm_debug/PP_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_rank</span><span class="si">}</span><span class="s2">/TP_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_rank</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
<span class="n">debug_dir</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="c1"># convert tensor name to valid file name</span>
<span class="n">fname</span> <span class="o">=</span> <span class="n">name</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"/"</span><span class="p">,</span> <span class="s2">"."</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">torch_to_numpy</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="n">np</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">debug_dir</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">fname</span><span class="si">}</span><span class="s2">-step</span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">.npy"</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="n">txt_format</span> <span class="o">=</span> <span class="s2">"</span><span class="si">%d</span><span class="s2">"</span> <span class="k">if</span> <span class="n">t</span><span class="o">.</span><span class="n">dtype</span> <span class="ow">in</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">]</span> <span class="k">else</span> <span class="s1">'</span><span class="si">%.18e</span><span class="s1">'</span>
<span class="n">np</span><span class="o">.</span><span class="n">savetxt</span><span class="p">(</span>
<span class="n">debug_dir</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">fname</span><span class="si">}</span><span class="s2">-step</span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">.txt"</span><span class="p">,</span>
<span class="n">t</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="c1"># savetxt accepts 2 dims only</span>
<span class="n">fmt</span><span class="o">=</span><span class="n">txt_format</span><span class="p">)</span></div>
<div class="viewcode-block" id="GenerationSession.decode_regular">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_regular">[docs]</a>
<span class="k">def</span> <span class="nf">decode_regular</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="p">,</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">kv_cache_block_pointers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">host_kv_cache_block_pointers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">outputs_generation_logits</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">def</span> <span class="nf">get_outputs_dict</span><span class="p">(</span><span class="n">output_ids</span><span class="p">):</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">'output_ids'</span><span class="p">]</span> <span class="o">=</span> <span class="n">output_ids</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_log_probs</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">'log_probs'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_cum_log_probs</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">'cum_log_probs'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span>
<span class="k">if</span> <span class="n">output_sequence_lengths</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span>
<span class="s1">'sequence_lengths'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">])</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">'context_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">'generation_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_generation_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">'medusa_output_tokens'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">'accept_lengths'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_temperature</span> <span class="o">!=</span> <span class="mf">0.0</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">'medusa_output_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="n">benchmark_profiler</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'benchmark_profiler'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">generation_phase_step_count</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">def</span> <span class="nf">profile_fn</span><span class="p">(</span><span class="n">benchmark_profiler_obj</span><span class="p">,</span> <span class="n">step_count</span><span class="p">):</span>
<span class="k">if</span> <span class="n">benchmark_profiler_obj</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">benchmark_profiler_obj</span><span class="o">.</span><span class="n">record_cuda_event</span><span class="p">(</span><span class="s1">'last_token'</span><span class="p">)</span>
<span class="n">benchmark_profiler_obj</span><span class="o">.</span><span class="n">record_elapsed_time</span><span class="p">(</span>
<span class="s1">'first_token'</span><span class="p">,</span> <span class="s1">'last_token'</span><span class="p">,</span> <span class="s1">'generation_time'</span><span class="p">)</span>
<span class="n">benchmark_profiler_obj</span><span class="o">.</span><span class="n">add_aux_info</span><span class="p">(</span><span class="s1">'generation_step_count'</span><span class="p">,</span>
<span class="n">step_count</span><span class="p">)</span>
<span class="n">next_step_tensors</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
<span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">context_logits</span><span class="p">,</span> <span class="n">generation_logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">handle_per_step</span><span class="p">(</span>
<span class="n">cache_indirections</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">,</span> <span class="n">host_kv_cache_block_pointers</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">ite</span><span class="p">,</span> <span class="n">sequence_limit_lengths</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="p">,</span> <span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">encoder_output</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="n">stopping_criteria</span><span class="p">,</span> <span class="n">logits_processor</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">if</span> <span class="n">benchmark_profiler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">benchmark_profiler</span><span class="o">.</span><span class="n">record_cuda_event</span><span class="p">(</span><span class="s1">'first_token'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">generation_phase_step_count</span> <span class="o">=</span> <span class="n">generation_phase_step_count</span> <span class="o">+</span> <span class="mi">1</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="n">context_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span><span class="p">:</span>
<span class="n">outputs_generation_logits</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">generation_logits</span><span class="p">)</span>
<span class="k">if</span> <span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="n">profile_fn</span><span class="p">(</span><span class="n">benchmark_profiler</span><span class="p">,</span> <span class="n">generation_phase_step_count</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="c1"># just hack away for now</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">return</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">final_output_ids</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">'context_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">'generation_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_generation_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">,</span> <span class="s2">"the custom decoder doesn't support medusa."</span>
<span class="n">profile_fn</span><span class="p">(</span><span class="n">benchmark_profiler</span><span class="p">,</span> <span class="n">generation_phase_step_count</span><span class="p">)</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">return</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">final_output_ids</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">'context_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">'generation_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_generation_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">None</span></div>
<div class="viewcode-block" id="GenerationSession.decode_stream">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_stream">[docs]</a>
<span class="k">def</span> <span class="nf">decode_stream</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="p">,</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">kv_cache_block_pointers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">host_kv_cache_block_pointers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">def</span> <span class="nf">get_outputs_dict</span><span class="p">(</span><span class="n">output_ids</span><span class="p">):</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">'output_ids'</span><span class="p">]</span> <span class="o">=</span> <span class="n">output_ids</span>
<span class="k">if</span> <span class="n">output_sequence_lengths</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span>
<span class="s1">'sequence_lengths'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">])</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">'context_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="n">next_step_tensors</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
<span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">context_logits</span><span class="p">,</span> <span class="n">generation_logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">handle_per_step</span><span class="p">(</span>
<span class="n">cache_indirections</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">,</span> <span class="n">host_kv_cache_block_pointers</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">ite</span><span class="p">,</span> <span class="n">sequence_limit_lengths</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="p">,</span> <span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">encoder_output</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="n">stopping_criteria</span><span class="p">,</span> <span class="n">logits_processor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="n">context_logits</span>
<span class="k">if</span> <span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">,</span>
<span class="n">in_progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">final_output_ids</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="k">return</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">final_output_ids</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="kc">None</span></div>
<div class="viewcode-block" id="GenerationSession.decode_batch">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_batch">[docs]</a>
<span class="k">def</span> <span class="nf">decode_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">input_ids</span><span class="p">,</span> <span class="n">context_lengths</span> <span class="o">=</span> <span class="n">_prepare_input_ids</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">sampling_config</span><span class="p">,</span>
<span class="n">streaming</span><span class="o">=</span><span class="n">streaming</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<span class="c1"># As dynamic_decoder uses torch's current stream, we must ensure it runs on the same stream that</span>
<span class="c1"># dynamic_decoder was set up with</span>
<div class="viewcode-block" id="GenerationSession.decode">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode">[docs]</a>
<span class="nd">@cuda_stream_guard</span>
<span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stop_words_list</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">bad_words_list</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">scfg</span> <span class="o">=</span> <span class="n">sampling_config</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">beam_width</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="n">host_context_lengths</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="k">assert</span> <span class="n">batch_size</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> \
<span class="s2">"Given batch size is different from the one used in setup(),"</span> \
<span class="s2">"rerun the setup function with the new batch size to avoid buffer overflow."</span>
<span class="k">assert</span> <span class="n">max_context_length</span> <span class="o"><=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_context_length</span><span class="p">,</span> \
<span class="s2">"Given input length is large then the one used in setup(),"</span> \
<span class="s2">"rerun the setup function with the new max_context_length to avoid buffer overflow."</span>
<span class="k">assert</span> <span class="n">beam_width</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_width</span><span class="p">,</span> \
<span class="s2">"Given beam width is different from the one used in setup(),"</span> \
<span class="s2">"rerun the setup function with the new beam width to avoid buffer overflow."</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o"><=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> \
<span class="s2">"Given sink token length is larger than shortest context length,"</span> \
<span class="s2">"rerun the setup function with a smaller sink token length."</span>
<span class="n">ite</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># index of local batches, will always be 0 if pp_size = 1</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="ow">and</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"Packed 2D input must have shape [1, <sum of input lengths>]"</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">__setup_decoder</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">'Buffer not allocated, please call setup first!'</span><span class="p">)</span>
<span class="n">sequence_limit_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># Sequence_lengths for the dynamic decoder still has the input paddings.</span>
<span class="n">sequence_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">max_context_length</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">cache_indirections</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="p">),</span>
<span class="mi">0</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="p">),</span>
<span class="mi">0</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="p">]</span> <span class="c1"># ping-pong buffers</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">max_num_tokens</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">max_num_tokens</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">))</span>
<span class="c1"># Init KV cache block manager</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">bubble_len</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
<span class="n">bubble_len</span> <span class="o">+=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span> <span class="o">-</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">)</span>
<span class="n">max_blocks_per_seq</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">+</span> <span class="n">bubble_len</span><span class="p">)</span> <span class="o">/</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_one_more_block</span><span class="p">:</span>
<span class="n">max_blocks_per_seq</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="n">blocks</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span> <span class="o">*</span> <span class="n">max_blocks_per_seq</span>
<span class="n">memory_pools</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span> <span class="o">=</span> <span class="n">KVCacheManager</span><span class="p">(</span>
<span class="n">memory_pools</span><span class="p">,</span> <span class="n">blocks</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span> <span class="n">max_blocks_per_seq</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_one_more_block</span><span class="p">)</span>
<span class="c1"># Add sequences to the manager</span>
<span class="k">for</span> <span class="n">bi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">generation_sequence</span> <span class="o">=</span> <span class="n">GenerationSequence</span><span class="p">(</span><span class="n">seq_idx</span><span class="o">=</span><span class="n">bi</span><span class="p">,</span>
<span class="n">batch_idx</span><span class="o">=</span><span class="n">bi</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="o">.</span><span class="n">add_sequence</span><span class="p">(</span><span class="n">generation_sequence</span><span class="p">,</span>
|
||
<span class="n">max_context_length</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_kv_cache_quant</span><span class="p">():</span>
<span class="c1"># torch does not support fp8 yet, so int8 is used as the storage dtype here.</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">int8</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">history_max_seq_length</span> <span class="o">=</span> <span class="p">[</span><span class="n">max_context_length</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span> <span class="o">=</span> <span class="n">KVCacheUpdater</span><span class="p">()</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span><span class="o">.</span><span class="n">init_paged_kv_cache</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span> <span class="n">kv_cache_type</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_manager</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">past_key_value_list</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span><span class="o">.</span><span class="n">init_linear_kv_cache</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads_kv</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span> <span class="n">kv_cache_type</span><span class="p">,</span>
<span class="n">past_key_value_list</span><span class="p">)</span>

<span class="n">stop_words_lens</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">stop_words_list_ptrs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">max_stop_words_len</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">stop_words_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">max_stop_words_len</span> <span class="o">=</span> <span class="n">stop_words_list</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="n">stop_words_lens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
<span class="n">max_stop_words_len</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'cuda'</span><span class="p">)</span>
<span class="n">stop_words_list_ptrs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">for</span> <span class="n">bi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">stop_words_list_ptrs</span><span class="p">[</span><span class="n">bi</span><span class="p">]</span> <span class="o">=</span> <span class="n">stop_words_list</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(</span>
<span class="p">)</span> <span class="o">+</span> <span class="n">bi</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">max_stop_words_len</span>
<span class="n">stop_words_list_ptrs</span> <span class="o">=</span> <span class="n">stop_words_list_ptrs</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'cuda'</span><span class="p">)</span>
<span class="n">stop_words_data</span> <span class="o">=</span> <span class="p">(</span><span class="n">stop_words_list_ptrs</span><span class="p">,</span> <span class="n">stop_words_lens</span><span class="p">,</span>
<span class="n">max_stop_words_len</span><span class="p">)</span>

<span class="n">bad_words_lens</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">bad_words_list_ptrs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">max_bad_words_len</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">bad_words_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">max_bad_words_len</span> <span class="o">=</span> <span class="n">bad_words_list</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="n">bad_words_lens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
<span class="n">max_bad_words_len</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'cuda'</span><span class="p">)</span>
<span class="n">bad_words_list_ptrs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">for</span> <span class="n">bi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">bad_words_list_ptrs</span><span class="p">[</span><span class="n">bi</span><span class="p">]</span> <span class="o">=</span> <span class="n">bad_words_list</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(</span>
<span class="p">)</span> <span class="o">+</span> <span class="n">bi</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">max_bad_words_len</span>
<span class="n">bad_words_list_ptrs</span> <span class="o">=</span> <span class="n">bad_words_list_ptrs</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'cuda'</span><span class="p">)</span>
<span class="n">bad_words_data</span> <span class="o">=</span> <span class="p">(</span><span class="n">bad_words_list_ptrs</span><span class="p">,</span> <span class="n">bad_words_lens</span><span class="p">,</span>
<span class="n">max_bad_words_len</span><span class="p">)</span>

<span class="c1"># Start the context phase and dispatch to the streaming or the regular decode loop.</span>
<span class="k">if</span> <span class="n">streaming</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode_stream</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span> <span class="n">sequence_lengths</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">ite</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">,</span> <span class="n">stop_words_data</span><span class="p">,</span> <span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">output_sequence_lengths</span><span class="p">,</span> <span class="n">return_dict</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="n">stopping_criteria</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode_regular</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span> <span class="n">sequence_lengths</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="n">ite</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">,</span> <span class="n">stop_words_data</span><span class="p">,</span> <span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">output_sequence_lengths</span><span class="p">,</span> <span class="n">return_dict</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="n">stopping_criteria</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
</div>



<div class="viewcode-block" id="ChatGLMGenerationSession">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.ChatGLMGenerationSession">[docs]</a>
<span class="k">class</span> <span class="nc">ChatGLMGenerationSession</span><span class="p">(</span><span class="n">GenerationSession</span><span class="p">):</span>

<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>

<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
<span class="n">model_config</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="p">,</span>
<span class="n">stream</span><span class="p">,</span>
<span class="p">)</span>

<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="o">=</span> <span class="kc">None</span>

<span class="k">def</span> <span class="nf">_prepare_context_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>

<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'max_context_length'</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>

<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">input_lengths_acc</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">([</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">cuda</span><span class="p">(),</span> <span class="n">context_lengths</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span><span class="p">]:</span><span class="n">input_lengths_acc</span><span class="p">[</span>
<span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span>
<span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">position_ids</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">)</span>

<span class="c1"># specialization for GLM series models</span>
<span class="k">if</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">"pad_id"</span><span class="p">]</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">50256</span><span class="p">,</span> <span class="mi">50259</span><span class="p">]:</span>
<span class="k">if</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">"pad_id"</span><span class="p">]</span> <span class="o">==</span> <span class="mi">50256</span><span class="p">:</span> <span class="c1"># glm_2b / glm_10b</span>
<span class="n">mask_ids</span> <span class="o">=</span> <span class="p">[</span><span class="mi">50260</span><span class="p">,</span> <span class="mi">50264</span><span class="p">,</span> <span class="mi">50263</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span> <span class="c1"># glm_10b_chinese / glm_large_chinese</span>
<span class="n">mask_ids</span> <span class="o">=</span> <span class="p">[</span><span class="mi">50003</span><span class="p">,</span> <span class="mi">50008</span><span class="p">,</span> <span class="mi">50009</span><span class="p">]</span>

<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="o">=</span> \
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">length</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">"input_ids"</span><span class="p">][</span><span class="n">i</span><span class="p">]</span>
<span class="n">mask_index</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">input_ids</span> <span class="o">==</span> <span class="nb">id</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">int</span><span class="p">()</span> <span class="k">for</span> <span class="nb">id</span> <span class="ow">in</span> <span class="n">mask_ids</span>
<span class="p">]</span>
<span class="n">tail_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="n">max_context_length</span><span class="p">])</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">mask_index</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">tail_index</span><span class="p">)</span>
<span class="n">mask_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">mask_index</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">min</span><span class="p">()</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">mask_index</span><span class="p">)</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">mask_index</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">length</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">2</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>

<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>

<span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">'position_ids'</span><span class="p">:</span> <span class="n">position_ids</span><span class="p">,</span>
<span class="s1">'last_token_ids'</span><span class="p">:</span> <span class="n">last_token_ids</span>
<span class="p">}</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">inputs</span><span class="p">[</span><span class="s1">'attention_mask'</span><span class="p">]</span> <span class="o">=</span> <span class="n">attention_mask</span>
<span class="k">return</span> <span class="n">inputs</span>

<span class="k">def</span> <span class="nf">_prepare_generation_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>

<span class="n">step</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'step'</span><span class="p">)</span>
<span class="n">num_beams</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'num_beams'</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span>

<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>

<span class="k">def</span> <span class="nf">_tile_beam_width_chatglm</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">new_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">num_beams</span>
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="n">tile_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">tile_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
<span class="k">return</span> <span class="n">new_tensor</span>

<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">step</span> <span class="o">+</span> <span class="mi">2</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width_chatglm</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">data</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="c1"># specialization for GLM series models</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">([[</span><span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]],</span> <span class="p">[</span><span class="n">step</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]])</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">([[</span><span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span><span class="p">],</span>
<span class="p">[</span><span class="n">step</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]])</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>

<span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">'position_ids'</span><span class="p">:</span> <span class="n">position_ids</span><span class="p">,</span>
<span class="s1">'last_token_ids'</span><span class="p">:</span> <span class="n">last_token_ids</span>
<span class="p">}</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">inputs</span><span class="p">[</span><span class="s1">'attention_mask'</span><span class="p">]</span> <span class="o">=</span> <span class="n">attention_mask</span>
<span class="k">return</span> <span class="n">inputs</span></div>



<div class="viewcode-block" id="QWenForCausalLMGenerationSession">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.QWenForCausalLMGenerationSession">[docs]</a>
<span class="k">class</span> <span class="nc">QWenForCausalLMGenerationSession</span><span class="p">(</span><span class="n">GenerationSession</span><span class="p">):</span>

<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">global_max_input_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span><span class="p">,</span>
<span class="n">global_max_output_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4096</span><span class="p">,</span>
<span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="n">debug_tensors_to_save</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="n">cuda_graph_mode</span><span class="p">,</span>
<span class="n">stream</span><span class="o">=</span><span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">global_max_input_length</span> <span class="o">=</span> <span class="n">global_max_input_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">global_max_output_length</span> <span class="o">=</span> <span class="n">global_max_output_length</span>

<div class="viewcode-block" id="QWenForCausalLMGenerationSession.generate">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.QWenForCausalLMGenerationSession.generate">[docs]</a>
<span class="k">def</span> <span class="nf">generate</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">runtime_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">max_input_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">input_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="n">max_new_tokens</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">max_new_tokens</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">global_max_output_length</span> <span class="o">-</span> <span class="n">max_input_length</span><span class="p">)</span>
<span class="c1"># Set up the session for this request: batch_size, max_context_length and max_new_tokens.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">setup</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="n">input_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_input_length</span><span class="p">,</span>
<span class="n">max_new_tokens</span><span class="o">=</span><span class="n">max_new_tokens</span><span class="p">)</span>
<span class="n">output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">input_lengths</span><span class="p">,</span> <span class="n">sampling_config</span><span class="p">)</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="k">if</span> <span class="n">runtime_rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">output_ids</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:]</span>
<span class="k">return</span> <span class="n">outputs</span></div>
</div>



<div class="viewcode-block" id="MambaLMHeadModelGenerationSession">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.MambaLMHeadModelGenerationSession">[docs]</a>
<span class="k">class</span> <span class="nc">MambaLMHeadModelGenerationSession</span><span class="p">(</span><span class="n">GenerationSession</span><span class="p">):</span>

<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span> <span class="n">ModelConfig</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span> <span class="o">=</span> <span class="n">model_config</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span> <span class="o">=</span> <span class="n">mapping</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span> <span class="o">=</span> <span class="n">_Runtime</span><span class="p">(</span><span class="n">engine_buffer</span><span class="p">,</span> <span class="n">mapping</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">'cuda:</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">runtime_rank</span><span class="w"> </span><span class="o">%</span><span class="w"> </span><span class="n">mapping</span><span class="o">.</span><span class="n">gpus_per_node</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># dynamic_decoder currently uses torch's current stream, so TRT enqueue must use the same stream here</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="o">=</span> <span class="n">debug_mode</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="o">=</span> <span class="n">debug_tensors_to_save</span>

<span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span> <span class="o">=</span> <span class="n">cuda_graph_mode</span>
<span class="c1"># Optional inputs for the dynamic decoder</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_decay</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_min</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_reset_ids</span> <span class="o">=</span> <span class="kc">None</span>
<span class="c1"># TODO: in tensorrt_llm/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp it's T; can it be float or half?</span>
<span class="bp">self</span><span class="o">.</span><span class="n">embedding_bias_opt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="c1"># Whether to use one more block in the paged KV cache.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">use_one_more_block</span> <span class="o">=</span> <span class="kc">False</span>

<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span> <span class="o">=</span> <span class="kc">False</span>

<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span> <span class="o">=</span> <span class="n">pad_vocab_size</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">)</span>

<span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'logits'</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">]:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">"Logits dtype not supported by decoder. Falling back to float32. You may want to change the logits dtype to float16 in your model definition."</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">classes</span><span class="o">.</span><span class="n">trtllm</span><span class="o">.</span><span class="n">DynamicDecodeOp</span><span class="p">(</span>
<span class="n">model_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>

<span class="bp">self</span><span class="o">.</span><span class="n">gather_tree</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">gather_tree</span>

<span class="n">expected_tensor_names</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'input_ids'</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'host_request_types'</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'host_context_lengths'</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'last_token_ids'</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
|
||
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="sa">f</span><span class="s1">'conv_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">'ssm_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'slot_mapping'</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">'past_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">'past_ssm_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">'present_ssm_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">model_config</span><span class="o">.</span><span class="n">use_custom_all_reduce</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'all_reduce_workspace'</span><span class="p">]</span>
<span class="n">found_tensor_names</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="ow">and</span> <span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">set</span><span class="p">(</span>
<span class="n">found_tensor_names</span><span class="p">):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">'self.remove_input_padding=</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="si">}</span><span class="s1">, self.paged_state=</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="si">}</span><span class="s1">'</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">"The following expected tensors are not found: </span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">)</span><span class="o">.</span><span class="n">difference</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">))</span><span class="si">}</span><span class="s2">"</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">"Those tensors in engine are not expected: </span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">)</span><span class="o">.</span><span class="n">difference</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">))</span><span class="si">}</span><span class="s2">"</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Expected tensor names: </span><span class="si">{</span><span class="n">expected_tensor_names</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Found tensor names: </span><span class="si">{</span><span class="n">found_tensor_names</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">"Tensor names in engine are not the same as expected."</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
<span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">)</span> <span class="o">-</span> <span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">))</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">mamba_d_state</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">mamba_d_state</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">mamba_d_conv</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">mamba_d_conv</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">mamba_expand</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">mamba_expand</span>
<div class="viewcode-block" id="MambaLMHeadModelGenerationSession.setup">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.MambaLMHeadModelGenerationSession.setup">[docs]</a>
<span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">max_attention_window_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">sink_token_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_max_input_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_manager</span><span class="p">:</span> <span class="n">LoraManager</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_uids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">medusa_choices</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="c1"># Store these params related to buffer size to check against</span>
|
||
<span class="c1"># the input shape with the params given in decode()</span>
<span class="k">assert</span> <span class="n">beam_width</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"Only support beam width = 1 now."</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_context_length</span> <span class="o">=</span> <span class="n">max_context_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">=</span> <span class="n">max_new_tokens</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">=</span> <span class="n">max_context_length</span> <span class="o">+</span> <span class="n">max_new_tokens</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mamba_expand</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_width</span> <span class="o">=</span> <span class="n">beam_width</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="p">{}</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'logits'</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mamba_conv1d_plugin</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_conv</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_conv</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ssm_state_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_state</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="c1"># we need two set of kv cache buffers, one for inputs, and the other for outputs.</span>
|
||
<span class="c1"># They will take turns to act as input and output buffers.</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">conv_state_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'1_present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">conv_state_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_ssm_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">ssm_state_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">conv_state_ptr</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
<span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">()],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span>
<span class="n">ssm_state_ptr</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
<span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_ssm_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">()],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'conv_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">conv_state_ptr</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'ssm_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">ssm_state_ptr</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span> <span class="o">=</span> <span class="kc">True</span></div>


<span class="k">def</span> <span class="nf">_get_context_shape_buffer</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">position_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">host_kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">hidden_states_input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">RuntimeTensor</span><span class="p">]:</span>
<span class="n">tensors</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">def</span> <span class="nf">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">add_tensor</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="n">name</span><span class="p">:</span> <span class="n">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">)})</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="s1">'input_ids'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">],</span> <span class="s1">'logits'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="s1">'last_token_ids'</span><span class="p">)</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">'conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'ssm_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">'ssm_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># conv state</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mamba_conv1d_plugin</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_conv</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_conv</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">conv_state</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">conv_state_shape</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">conv_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'past_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="n">present</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">present</span><span class="p">],</span> <span class="n">present</span><span class="p">)</span>
<span class="c1"># ssm state</span>
<span class="n">ssm_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_ssm_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">ssm_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'past_ssm_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">ssm_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'present_ssm_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="c1"># context request</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">'host_request_types'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span> <span class="s1">'host_context_lengths'</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">slot_mapping</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">slot_mapping</span><span class="p">,</span> <span class="s1">'slot_mapping'</span><span class="p">)</span>
<span class="c1"># all reduce</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_custom_all_reduce</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span><span class="p">,</span> <span class="s1">'all_reduce_workspace'</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tensors</span>
<span class="k">def</span> <span class="nf">_get_next_step_shape_buffer</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">position_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">host_kv_cache_block_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">hidden_states_input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="n">tensors</span> <span class="o">=</span> <span class="p">{}</span> <span class="c1"># Dict[str, RuntimeTensor]</span>
<span class="k">def</span> <span class="nf">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">add_tensor</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="n">name</span><span class="p">:</span> <span class="n">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">)})</span>
<span class="k">def</span> <span class="nf">add_tensor_with_shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
<span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">override_shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)})</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">input_ids_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">input_ids_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="s1">'input_ids'</span><span class="p">,</span> <span class="n">input_ids_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">],</span> <span class="s1">'logits'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="s1">'last_token_ids'</span><span class="p">)</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">'conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'ssm_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">'ssm_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># conv state</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mamba_conv1d_plugin</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_conv</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_inner</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mamba_d_conv</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'1_present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">'past_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span> <span class="n">conv_state_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">'past_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span> <span class="n">conv_state_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'1_present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="c1"># ssm state</span>
<span class="n">ssm_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_ssm_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">ssm_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'past_ssm_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">ssm_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'present_ssm_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="c1"># generation requests</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">'host_request_types'</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span> <span class="s1">'host_context_lengths'</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">slot_mapping</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">slot_mapping</span><span class="p">,</span> <span class="s1">'slot_mapping'</span><span class="p">)</span>
<span class="c1"># all reduce</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_custom_all_reduce</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span><span class="p">,</span> <span class="s1">'all_reduce_workspace'</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tensors</span>
<span class="k">def</span> <span class="nf">_prepare_context_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'last_token_ids'</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">}</span>
<span class="k">return</span> <span class="n">ret</span>
<span class="k">def</span> <span class="nf">_prepare_generation_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'last_token_ids'</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">}</span>
<span class="k">return</span> <span class="n">ret</span>
<span class="k">def</span> <span class="nf">_capture_cuda_graph_and_instantiate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">,</span> <span class="n">stream</span><span class="p">,</span> <span class="n">step</span><span class="p">):</span>
<span class="n">instance_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span>
<span class="c1"># Mamba model I/O shape isn't changed in generation phase</span>
|
||
<span class="c1"># Create two cuda graph once.If cuda graph has already existed, skip it.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span>
<span class="c1"># WAR for TRT 9.x</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">preview_trt_version</span><span class="p">()</span> <span class="ow">and</span> <span class="n">step</span> <span class="o"><</span> <span class="mi">3</span><span class="p">:</span>
<span class="k">return</span>
<span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_capture_cuda_graph_and_instantiate</span><span class="p">(</span>
<span class="n">context</span><span class="p">,</span> <span class="n">stream</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span></div>
</pre></div>
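<p>A minimal usage sketch (illustrative only, not part of the module above; it assumes <code>session</code> is an already-constructed <code>MambaLMHeadModelGenerationSession</code>): buffers must be allocated with <code>setup()</code> before decoding, and only <code>beam_width == 1</code> is supported.</p>
<div class="highlight"><pre><span></span># Hypothetical sketch: `session` is assumed to be a MambaLMHeadModelGenerationSession
# built from a deserialized TensorRT engine and its ModelConfig.
session.setup(batch_size=2,
              max_context_length=128,
              max_new_tokens=64,
              beam_width=1)  # the class asserts beam_width == 1
</pre></div>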
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>© Copyright 2023, NVIDIA.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html> |