TensorRT-LLMs/_modules/tensorrt_llm/runtime/model_runner.html
2023-12-04 16:26:13 +08:00


<!DOCTYPE html>
<html class="writer-html5" lang="en" data-content_root="../../../">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>tensorrt_llm.runtime.model_runner &mdash; tensorrt_llm documentation</title>
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=80d5e7a1" />
<link rel="stylesheet" type="text/css" href="../../../_static/css/theme.css?v=19f00094" />
<!--[if lt IE 9]>
<script src="../../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script src="../../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../../_static/doctools.js?v=888ff710"></script>
<script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../../index.html" class="icon icon-home">
tensorrt_llm
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Contents:</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../architecture.html">TensorRT-LLM Architecture</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../gpt_runtime.html">C++ GPT Runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../batch_manager.html">The Batch Manager in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../gpt_attention.html">Multi-head, Multi-query and Group-query Attention</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../precision.html">Numerical Precision</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../performance.html">Performance of TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../2023-05-19-how-to-debug.html">How to debug</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../2023-05-17-how-to-add-a-new-model.html">How to add a new model</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../graph-rewriting.html">Graph Rewriting Module</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../memory.html">Memory Usage of TensorRT-LLM</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.layers.html">Layers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.functional.html">Functionals</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.models.html">Models</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../_cpp_gen/runtime.html">Runtime</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../../index.html">tensorrt_llm</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<h1>Source code for tensorrt_llm.runtime.model_runner</h1><div class="highlight"><pre>
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="kn">import</span> <span class="nn">copy</span>
<span class="kn">import</span> <span class="nn">json</span>
<span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">tensorrt_llm</span>
<span class="kn">import</span> <span class="nn">tensorrt_llm.profiler</span> <span class="k">as</span> <span class="nn">profiler</span>
<span class="kn">from</span> <span class="nn">tensorrt_llm.logger</span> <span class="kn">import</span> <span class="n">logger</span>
<span class="kn">from</span> <span class="nn">tensorrt_llm.quantization</span> <span class="kn">import</span> <span class="n">QuantMode</span>
<span class="kn">from</span> <span class="nn">tensorrt_llm.runtime</span> <span class="kn">import</span> <span class="p">(</span><span class="n">ChatGLMGenerationSession</span><span class="p">,</span> <span class="n">GenerationSession</span><span class="p">,</span>
<span class="n">LoraManager</span><span class="p">,</span> <span class="n">ModelConfig</span><span class="p">,</span>
<span class="n">QWenForCausalLMGenerationSession</span><span class="p">,</span>
<span class="n">SamplingConfig</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">get_engine_name</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">tp_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">pp_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get the serialized engine file name.</span>
<span class="sd"> Args:</span>
<span class="sd"> model (str):</span>
<span class="sd"> Model name, e.g., bloom, gpt.</span>
<span class="sd"> dtype (str):</span>
<span class="sd"> Data type, e.g., float32, float16, bfloat16.</span>
<span class="sd"> tp_size (int):</span>
<span class="sd"> The tensor parallelism size.</span>
<span class="sd"> pp_size (int):</span>
<span class="sd"> The pipeline parallelism size.</span>
<span class="sd"> rank (int):</span>
<span class="sd"> The rank ID.</span>
<span class="sd"> Returns:</span>
<span class="sd"> str: The serialized engine file name.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">pp_size</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">return</span> <span class="s1">&#39;</span><span class="si">{}</span><span class="s1">_</span><span class="si">{}</span><span class="s1">_tp</span><span class="si">{}</span><span class="s1">_rank</span><span class="si">{}</span><span class="s1">.engine&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">tp_size</span><span class="p">,</span> <span class="n">rank</span><span class="p">)</span>
<span class="k">return</span> <span class="s1">&#39;</span><span class="si">{}</span><span class="s1">_</span><span class="si">{}</span><span class="s1">_tp</span><span class="si">{}</span><span class="s1">_pp</span><span class="si">{}</span><span class="s1">_rank</span><span class="si">{}</span><span class="s1">.engine&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">tp_size</span><span class="p">,</span>
<span class="n">pp_size</span><span class="p">,</span> <span class="n">rank</span><span class="p">)</span>
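# Illustrative sketch (not part of tensorrt_llm): the engine file naming rule
# implemented by get_engine_name above, rewritten with f-strings so the two
# cases are easy to compare. The "_pp{pp_size}" segment only appears when
# pipeline parallelism is used (pp_size != 1). engine_file_name is a
# hypothetical helper name chosen for this example.
def engine_file_name(model, dtype, tp_size, pp_size, rank):
    if pp_size == 1:
        return f'{model}_{dtype}_tp{tp_size}_rank{rank}.engine'
    return f'{model}_{dtype}_tp{tp_size}_pp{pp_size}_rank{rank}.engine'


# Usage: one file per rank, e.g. a 2-way tensor-parallel GPT engine in fp16.
print([engine_file_name('gpt', 'float16', 2, 1, r) for r in range(2)])
# ['gpt_float16_tp2_rank0.engine', 'gpt_float16_tp2_rank1.engine']
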
<span class="k">def</span> <span class="nf">read_config</span><span class="p">(</span><span class="n">config_path</span><span class="p">:</span> <span class="n">Path</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">ModelConfig</span><span class="p">,</span> <span class="nb">dict</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Read the engine config file, create a ModelConfig instance from it, and return that instance</span>
<span class="sd"> together with the remaining config fields in a dict.</span>
<span class="sd"> Args:</span>
<span class="sd"> config_path (Path):</span>
<span class="sd"> The path to the engine config file.</span>
<span class="sd"> Returns:</span>
<span class="sd"> Tuple[ModelConfig, dict]: The ModelConfig instance and a dict of the other config fields.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">config_path</span><span class="p">,</span> <span class="s1">&#39;r&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="n">builder_config</span> <span class="o">=</span> <span class="n">config</span><span class="p">[</span><span class="s1">&#39;builder_config&#39;</span><span class="p">]</span>
<span class="n">model_name</span> <span class="o">=</span> <span class="n">builder_config</span><span class="p">[</span><span class="s1">&#39;name&#39;</span><span class="p">]</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">builder_config</span><span class="p">[</span><span class="s1">&#39;precision&#39;</span><span class="p">]</span>
<span class="n">tp_size</span> <span class="o">=</span> <span class="n">builder_config</span><span class="p">[</span><span class="s1">&#39;tensor_parallel&#39;</span><span class="p">]</span>
<span class="n">pp_size</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;pipeline_parallel&#39;</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">world_size</span> <span class="o">=</span> <span class="n">tp_size</span> <span class="o">*</span> <span class="n">pp_size</span>
<span class="k">assert</span> <span class="n">world_size</span> <span class="o">==</span> <span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">mpi_world_size</span><span class="p">(),</span> \
<span class="sa">f</span><span class="s1">&#39;Engine world size (</span><span class="si">{</span><span class="n">tp_size</span><span class="si">}</span><span class="s1"> * </span><span class="si">{</span><span class="n">pp_size</span><span class="si">}</span><span class="s1">) != Runtime world size (</span><span class="si">{</span><span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">mpi_world_size</span><span class="p">()</span><span class="si">}</span><span class="s1">)&#39;</span>
<span class="n">num_heads</span> <span class="o">=</span> <span class="n">builder_config</span><span class="p">[</span><span class="s1">&#39;num_heads&#39;</span><span class="p">]</span>
<span class="k">assert</span> <span class="n">num_heads</span> <span class="o">%</span> <span class="n">tp_size</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> \
<span class="sa">f</span><span class="s2">&quot;The number of heads (</span><span class="si">{</span><span class="n">num_heads</span><span class="si">}</span><span class="s2">) is not a multiple of tp_size (</span><span class="si">{</span><span class="n">tp_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="n">num_kv_heads</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;num_kv_heads&#39;</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">)</span>
<span class="c1"># TODO: multi_query_mode should be removed</span>
<span class="n">multi_query_mode</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;multi_query_mode&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="k">if</span> <span class="n">multi_query_mode</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;`multi_query_mode` config is deprecated. Please rebuild the engine.&quot;</span>
<span class="p">)</span>
<span class="c1"># num_kv_heads, if present in the config, overrides multi_query_mode</span>
<span class="k">if</span> <span class="n">multi_query_mode</span> <span class="ow">and</span> <span class="p">(</span><span class="s1">&#39;num_kv_heads&#39;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">builder_config</span><span class="p">):</span>
<span class="n">num_kv_heads</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="o">//</span> <span class="n">tp_size</span>
<span class="n">num_kv_heads</span> <span class="o">=</span> <span class="p">(</span><span class="n">num_kv_heads</span> <span class="o">+</span> <span class="n">tp_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">tp_size</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="n">builder_config</span><span class="p">[</span><span class="s1">&#39;hidden_size&#39;</span><span class="p">]</span> <span class="o">//</span> <span class="n">tp_size</span>
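# Illustrative sketch (not part of tensorrt_llm): how read_config derives the
# per-rank attention shape from builder_config. per_rank_attention_dims is a
# hypothetical helper. It mirrors the arithmetic above, including the
# deprecated multi_query_mode fallback, but raises ValueError where the
# original uses an assert.
def per_rank_attention_dims(builder_config, tp_size):
    num_heads = builder_config['num_heads']
    if num_heads % tp_size != 0:
        raise ValueError(
            f'num_heads ({num_heads}) is not a multiple of tp_size ({tp_size})')
    num_kv_heads = builder_config.get('num_kv_heads', num_heads)
    # An explicit num_kv_heads wins over the legacy multi_query_mode flag.
    if (builder_config.get('multi_query_mode', False)
            and 'num_kv_heads' not in builder_config):
        num_kv_heads = 1
    # Query heads are split evenly. KV heads use ceiling division, so when there
    # are fewer KV heads than ranks, each rank still keeps one (replicated) KV head.
    return (num_heads // tp_size,
            (num_kv_heads + tp_size - 1) // tp_size,
            builder_config['hidden_size'] // tp_size)


# Usage: a grouped-query config (32 query heads, 8 KV heads) and a legacy
# multi-query config, both on 4 GPUs.
print(per_rank_attention_dims(
    {'num_heads': 32, 'num_kv_heads': 8, 'hidden_size': 4096}, 4))
# (8, 2, 1024)
print(per_rank_attention_dims(
    {'num_heads': 32, 'multi_query_mode': True, 'hidden_size': 4096}, 4))
# (8, 1, 1024)
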
<span class="n">vocab_size</span> <span class="o">=</span> <span class="n">builder_config</span><span class="p">[</span><span class="s1">&#39;vocab_size&#39;</span><span class="p">]</span>
<span class="n">num_layers</span> <span class="o">=</span> <span class="n">builder_config</span><span class="p">[</span><span class="s1">&#39;num_layers&#39;</span><span class="p">]</span>
<span class="n">cross_attention</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;cross_attention&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="n">has_position_embedding</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;has_position_embedding&#39;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="n">has_token_type_embedding</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;has_token_type_embedding&#39;</span><span class="p">,</span>
<span class="kc">False</span><span class="p">)</span>
<span class="n">gather_all_token_logits</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;gather_all_token_logits&#39;</span><span class="p">,</span>
<span class="kc">False</span><span class="p">)</span>
<span class="n">max_prompt_embedding_table_size</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
<span class="s1">&#39;max_prompt_embedding_table_size&#39;</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">quant_mode</span> <span class="o">=</span> <span class="n">QuantMode</span><span class="p">(</span><span class="n">builder_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;quant_mode&#39;</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="n">lora_target_modules</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;lora_target_modules&#39;</span><span class="p">)</span>
<span class="n">plugin_config</span> <span class="o">=</span> <span class="n">config</span><span class="p">[</span><span class="s1">&#39;plugin_config&#39;</span><span class="p">]</span>
<span class="n">use_gpt_attention_plugin</span> <span class="o">=</span> <span class="nb">bool</span><span class="p">(</span><span class="n">plugin_config</span><span class="p">[</span><span class="s1">&#39;gpt_attention_plugin&#39;</span><span class="p">])</span>
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">plugin_config</span><span class="p">[</span><span class="s1">&#39;remove_input_padding&#39;</span><span class="p">]</span>
<span class="n">paged_kv_cache</span> <span class="o">=</span> <span class="n">plugin_config</span><span class="p">[</span><span class="s1">&#39;paged_kv_cache&#39;</span><span class="p">]</span>
<span class="n">tokens_per_block</span> <span class="o">=</span> <span class="n">plugin_config</span><span class="p">[</span><span class="s1">&#39;tokens_per_block&#39;</span><span class="p">]</span>
<span class="n">use_custom_all_reduce</span> <span class="o">=</span> <span class="n">plugin_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;use_custom_all_reduce&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="n">lora_plugin</span> <span class="o">=</span> <span class="n">plugin_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;lora_plugin&#39;</span><span class="p">)</span>
<span class="n">model_config</span> <span class="o">=</span> <span class="n">ModelConfig</span><span class="p">(</span>
<span class="n">vocab_size</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span>
<span class="n">num_layers</span><span class="o">=</span><span class="n">num_layers</span><span class="p">,</span>
<span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span>
<span class="n">num_kv_heads</span><span class="o">=</span><span class="n">num_kv_heads</span><span class="p">,</span>
<span class="n">hidden_size</span><span class="o">=</span><span class="n">hidden_size</span><span class="p">,</span>
<span class="n">gpt_attention_plugin</span><span class="o">=</span><span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="o">=</span><span class="n">remove_input_padding</span><span class="p">,</span>
<span class="n">model_name</span><span class="o">=</span><span class="n">model_name</span><span class="p">,</span>
<span class="n">paged_kv_cache</span><span class="o">=</span><span class="n">paged_kv_cache</span><span class="p">,</span>
<span class="n">cross_attention</span><span class="o">=</span><span class="n">cross_attention</span><span class="p">,</span>
<span class="n">has_position_embedding</span><span class="o">=</span><span class="n">has_position_embedding</span><span class="p">,</span>
<span class="n">has_token_type_embedding</span><span class="o">=</span><span class="n">has_token_type_embedding</span><span class="p">,</span>
<span class="n">tokens_per_block</span><span class="o">=</span><span class="n">tokens_per_block</span><span class="p">,</span>
<span class="n">max_prompt_embedding_table_size</span><span class="o">=</span><span class="n">max_prompt_embedding_table_size</span><span class="p">,</span>
<span class="n">quant_mode</span><span class="o">=</span><span class="n">quant_mode</span><span class="p">,</span>
<span class="n">gather_all_token_logits</span><span class="o">=</span><span class="n">gather_all_token_logits</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">use_custom_all_reduce</span><span class="o">=</span><span class="n">use_custom_all_reduce</span><span class="p">,</span>
<span class="n">lora_plugin</span><span class="o">=</span><span class="n">lora_plugin</span><span class="p">,</span>
<span class="n">lora_target_modules</span><span class="o">=</span><span class="n">lora_target_modules</span><span class="p">)</span>
<span class="n">other_config</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;world_size&#39;</span><span class="p">:</span> <span class="n">world_size</span><span class="p">,</span>
<span class="s1">&#39;tp_size&#39;</span><span class="p">:</span> <span class="n">tp_size</span><span class="p">,</span>
<span class="s1">&#39;pp_size&#39;</span><span class="p">:</span> <span class="n">pp_size</span><span class="p">,</span>
<span class="s1">&#39;max_batch_size&#39;</span><span class="p">:</span> <span class="n">builder_config</span><span class="p">[</span><span class="s1">&#39;max_batch_size&#39;</span><span class="p">],</span>
<span class="s1">&#39;max_input_len&#39;</span><span class="p">:</span> <span class="n">builder_config</span><span class="p">[</span><span class="s1">&#39;max_input_len&#39;</span><span class="p">]</span>
<span class="p">}</span>
<span class="k">return</span> <span class="n">model_config</span><span class="p">,</span> <span class="n">other_config</span>
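# Illustrative sketch (not part of tensorrt_llm): the config.json keys that
# read_config indexes directly with [...] and therefore requires. Every other
# field it reads uses .get() with a default. missing_config_keys is a
# hypothetical helper for checking a config before loading an engine.
REQUIRED_BUILDER_KEYS = ('name', 'precision', 'tensor_parallel', 'num_heads',
                         'hidden_size', 'vocab_size', 'num_layers',
                         'max_batch_size', 'max_input_len')
REQUIRED_PLUGIN_KEYS = ('gpt_attention_plugin', 'remove_input_padding',
                        'paged_kv_cache', 'tokens_per_block')


def missing_config_keys(config):
    builder = config.get('builder_config', {})
    plugin = config.get('plugin_config', {})
    missing = [f'builder_config.{k}' for k in REQUIRED_BUILDER_KEYS
               if k not in builder]
    missing += [f'plugin_config.{k}' for k in REQUIRED_PLUGIN_KEYS
                if k not in plugin]
    return missing


# Usage: a config that only sets the model name is missing the rest.
print(missing_config_keys({'builder_config': {'name': 'gpt'},
                           'plugin_config': {}})[:2])
# ['builder_config.precision', 'builder_config.tensor_parallel']
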
<div class="viewcode-block" id="ModelRunner">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.ModelRunner">[docs]</a>
<span class="k">class</span> <span class="nc">ModelRunner</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> An interface class that wraps GenerationSession and provides generation methods.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">session</span><span class="p">:</span> <span class="n">GenerationSession</span><span class="p">,</span>
<span class="n">max_batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_input_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">lora_manager</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LoraManager</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Create a ModelRunner instance.</span>
<span class="sd"> Prefer the from_dir method, which loads the engine and creates a ModelRunner instance.</span>
<span class="sd"> Args:</span>
<span class="sd"> session (GenerationSession):</span>
<span class="sd"> The TensorRT session created from an engine.</span>
<span class="sd"> max_batch_size (int):</span>
<span class="sd"> The maximum batch size allowed for the input.</span>
<span class="sd"> max_input_len (int):</span>
<span class="sd"> The maximum input length (in tokens) allowed for each input sequence.</span>
<span class="sd"> lora_manager (Optional[LoraManager]):</span>
<span class="sd"> The LoRA manager that handles LoRA weights. Defaults to None.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">session</span> <span class="o">=</span> <span class="n">session</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="o">=</span> <span class="n">max_batch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_input_len</span> <span class="o">=</span> <span class="n">max_input_len</span>
<span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span> <span class="o">=</span> <span class="n">lora_manager</span>
<div class="viewcode-block" id="ModelRunner.from_dir">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.ModelRunner.from_dir">[docs]</a>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">from_dir</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span>
<span class="n">engine_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">lora_dir</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s1">&#39;ModelRunner&#39;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Create a ModelRunner instance from an engine directory.</span>
<span class="sd"> Args:</span>
<span class="sd"> engine_dir (str):</span>
<span class="sd"> The directory that contains the serialized engine files and config files.</span>
<span class="sd"> lora_dir (Optional[str]):</span>
<span class="sd"> The directory that contains the LoRA weights. Defaults to None.</span>
<span class="sd"> rank (int):</span>
<span class="sd"> The runtime rank ID.</span>
<span class="sd"> debug_mode (bool):</span>
<span class="sd"> Whether to turn on debug mode.</span>
<span class="sd"> Returns:</span>
<span class="sd"> ModelRunner: An instance of ModelRunner.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># session setup</span>
<span class="n">engine_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">)</span>
<span class="n">config_path</span> <span class="o">=</span> <span class="n">engine_dir</span> <span class="o">/</span> <span class="s2">&quot;config.json&quot;</span>
<span class="n">model_config</span><span class="p">,</span> <span class="n">other_config</span> <span class="o">=</span> <span class="n">read_config</span><span class="p">(</span><span class="n">config_path</span><span class="p">)</span>
        <span class="n">world_size</span> <span class="o">=</span> <span class="n">other_config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;world_size&#39;</span><span class="p">)</span>
        <span class="n">tp_size</span> <span class="o">=</span> <span class="n">other_config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;tp_size&#39;</span><span class="p">)</span>
        <span class="n">pp_size</span> <span class="o">=</span> <span class="n">other_config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;pp_size&#39;</span><span class="p">)</span>
        <span class="n">runtime_mapping</span> <span class="o">=</span> <span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">Mapping</span><span class="p">(</span><span class="n">world_size</span><span class="o">=</span><span class="n">world_size</span><span class="p">,</span>
                                               <span class="n">rank</span><span class="o">=</span><span class="n">rank</span><span class="p">,</span>
                                               <span class="n">tp_size</span><span class="o">=</span><span class="n">tp_size</span><span class="p">,</span>
                                               <span class="n">pp_size</span><span class="o">=</span><span class="n">pp_size</span><span class="p">)</span>
        <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="n">rank</span> <span class="o">%</span> <span class="n">runtime_mapping</span><span class="o">.</span><span class="n">gpus_per_node</span><span class="p">)</span>
        <span class="n">engine_name</span> <span class="o">=</span> <span class="n">get_engine_name</span><span class="p">(</span><span class="n">model_config</span><span class="o">.</span><span class="n">model_name</span><span class="p">,</span>
                                      <span class="n">model_config</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">tp_size</span><span class="p">,</span> <span class="n">pp_size</span><span class="p">,</span>
                                      <span class="n">rank</span><span class="p">)</span>
        <span class="n">serialize_path</span> <span class="o">=</span> <span class="n">engine_dir</span> <span class="o">/</span> <span class="n">engine_name</span>
        <span class="n">profiler</span><span class="o">.</span><span class="n">start</span><span class="p">(</span><span class="s1">&#39;load tensorrt_llm engine&#39;</span><span class="p">)</span>
        <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">serialize_path</span><span class="p">,</span> <span class="s1">&#39;rb&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
            <span class="n">engine_buffer</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">()</span>
        <span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">model_name</span> <span class="ow">in</span> <span class="p">(</span><span class="s1">&#39;chatglm_6b&#39;</span><span class="p">,</span> <span class="s1">&#39;glm_10b&#39;</span><span class="p">):</span>
            <span class="n">session_cls</span> <span class="o">=</span> <span class="n">ChatGLMGenerationSession</span>
        <span class="k">elif</span> <span class="n">model_config</span><span class="o">.</span><span class="n">model_name</span> <span class="o">==</span> <span class="s1">&#39;qwen&#39;</span><span class="p">:</span>
            <span class="n">session_cls</span> <span class="o">=</span> <span class="n">QWenForCausalLMGenerationSession</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">session_cls</span> <span class="o">=</span> <span class="n">GenerationSession</span>
        <span class="n">session</span> <span class="o">=</span> <span class="n">session_cls</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span>
                              <span class="n">engine_buffer</span><span class="p">,</span>
                              <span class="n">runtime_mapping</span><span class="p">,</span>
                              <span class="n">debug_mode</span><span class="o">=</span><span class="n">debug_mode</span><span class="p">)</span>
        <span class="n">profiler</span><span class="o">.</span><span class="n">stop</span><span class="p">(</span><span class="s1">&#39;load tensorrt_llm engine&#39;</span><span class="p">)</span>
        <span class="n">loading_time</span> <span class="o">=</span> <span class="n">profiler</span><span class="o">.</span><span class="n">elapsed_time_in_sec</span><span class="p">(</span><span class="s2">&quot;load tensorrt_llm engine&quot;</span><span class="p">)</span>
        <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Load engine takes: </span><span class="si">{</span><span class="n">loading_time</span><span class="si">}</span><span class="s1"> sec&#39;</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">session</span><span class="o">.</span><span class="n">use_lora_plugin</span><span class="p">:</span>
            <span class="k">assert</span> <span class="n">lora_dir</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> \
                <span class="s2">&quot;lora_dir should not be None for engine built with lora_plugin enabled.&quot;</span>
            <span class="n">lora_manager</span> <span class="o">=</span> <span class="n">LoraManager</span><span class="p">()</span>
            <span class="n">lora_manager</span><span class="o">.</span><span class="n">load_from_hf</span><span class="p">(</span><span class="n">model_dir</span><span class="o">=</span><span class="n">lora_dir</span><span class="p">,</span>
                                      <span class="n">model_config</span><span class="o">=</span><span class="n">model_config</span><span class="p">,</span>
                                      <span class="n">runtime_mapping</span><span class="o">=</span><span class="n">runtime_mapping</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">lora_manager</span> <span class="o">=</span> <span class="kc">None</span>
        <span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="n">session</span><span class="p">,</span> <span class="n">lora_manager</span><span class="o">=</span><span class="n">lora_manager</span><span class="p">,</span> <span class="o">**</span><span class="n">other_config</span><span class="p">)</span></div>
    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">remove_input_padding</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">session</span><span class="o">.</span><span class="n">remove_input_padding</span>
    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">use_lora_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">session</span><span class="o">.</span><span class="n">use_lora_plugin</span>
    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">max_prompt_embedding_table_size</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">session</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span>
    <span class="k">def</span> <span class="nf">_prepare_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_input_ids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
                        <span class="n">pad_id</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span>
        <span class="c1"># Remove potential additional dim, cast to int32</span>
        <span class="n">batch_input_ids</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">x</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">batch_input_ids</span>
        <span class="p">]</span>
        <span class="n">input_lengths</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">batch_input_ids</span><span class="p">]</span>
        <span class="n">max_length</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">input_lengths</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">max_length</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_input_len</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
                <span class="sa">f</span><span class="s2">&quot;Maximum input length (</span><span class="si">{</span><span class="n">max_length</span><span class="si">}</span><span class="s2">) exceeds the engine limit (</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">max_input_len</span><span class="si">}</span><span class="s2">)&quot;</span>
            <span class="p">)</span>
        <span class="n">batch_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch_input_ids</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">batch_size</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
                <span class="sa">f</span><span class="s2">&quot;Input batch size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">) exceeds the engine limit (</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">max_batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
            <span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
            <span class="n">batch_input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span><span class="n">batch_input_ids</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="c1"># Right padding for trt-llm</span>
            <span class="n">paddings</span> <span class="o">=</span> <span class="p">[</span>
                <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">max_length</span> <span class="o">-</span> <span class="n">l</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="n">pad_id</span>
                <span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="n">input_lengths</span>
            <span class="p">]</span>
            <span class="n">batch_input_ids</span> <span class="o">=</span> <span class="p">[</span>
                <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">pad</span><span class="p">])</span> <span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">pad</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">batch_input_ids</span><span class="p">,</span> <span class="n">paddings</span><span class="p">)</span>
            <span class="p">]</span>
            <span class="n">batch_input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_input_ids</span><span class="p">)</span>
        <span class="n">input_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">input_lengths</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">batch_input_ids</span><span class="p">,</span> <span class="n">input_lengths</span>
    <span class="k">def</span> <span class="nf">_prepare_outputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">outputs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">dict</span><span class="p">],</span>
                         <span class="n">input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">:</span>
        <span class="k">if</span> <span class="n">outputs</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="s1">&#39;context_logits&#39;</span> <span class="ow">in</span> <span class="n">outputs</span><span class="p">:</span>
            <span class="n">batch_size</span> <span class="o">=</span> <span class="n">input_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
            <span class="n">context_logits</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;context_logits&#39;</span><span class="p">]</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
                <span class="n">context_logits</span> <span class="o">=</span> <span class="n">context_logits</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">end_dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
                <span class="n">seg_points</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">input_lengths</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
                <span class="n">context_logits</span> <span class="o">=</span> <span class="p">[</span>
                    <span class="n">context_logits</span><span class="p">[</span><span class="n">s</span><span class="p">:</span><span class="n">e</span><span class="p">]</span>
                    <span class="k">for</span> <span class="n">s</span><span class="p">,</span> <span class="n">e</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">seg_points</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">seg_points</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span>
                <span class="p">]</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">context_logits</span> <span class="o">=</span> <span class="p">[</span>
                    <span class="n">context_logits</span><span class="p">[</span><span class="n">bidx</span><span class="p">,</span> <span class="p">:</span><span class="n">input_lengths</span><span class="p">[</span><span class="n">bidx</span><span class="p">]]</span>
                    <span class="k">for</span> <span class="n">bidx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
                <span class="p">]</span>
            <span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;context_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_logits</span>
        <span class="k">return</span> <span class="n">outputs</span>
<div class="viewcode-block" id="ModelRunner.generate">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.ModelRunner.generate">[docs]</a>
    <span class="k">def</span> <span class="nf">generate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                 <span class="n">batch_input_ids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
                 <span class="n">sampling_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">SamplingConfig</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
                 <span class="n">prompt_table_path</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
                 <span class="n">prompt_tasks</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
                 <span class="n">lora_uids</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">list</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
                 <span class="n">streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
                 <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">dict</span><span class="p">]:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Generates sequences of token ids.</span>
<span class="sd">        The generation-controlling parameters are set in the sampling_config; a default one is used if none is passed.</span>
<span class="sd">        You can override any of the sampling_config&#39;s attributes by passing the corresponding parameters.</span>
<span class="sd">        Args:</span>
<span class="sd">            batch_input_ids (List[torch.Tensor]):</span>
<span class="sd">                A list of input id tensors. Each tensor is of shape (sequence_length, ).</span>
<span class="sd">            sampling_config (Optional[SamplingConfig]):</span>
<span class="sd">                The sampling configuration to be used as base parametrization for the generation call.</span>
<span class="sd">                The passed **kwargs matching the sampling_config&#39;s attributes will override them.</span>
<span class="sd">                If the sampling_config is not provided, a default will be used.</span>
<span class="sd">            prompt_table_path (str):</span>
<span class="sd">                The file path of the prompt table (.npy format, exported by nemo_prompt_convert.py).</span>
<span class="sd">            prompt_tasks (str):</span>
<span class="sd">                The prompt tuning task ids for the input batch, given as a comma-separated list (e.g., 0,3,1,0).</span>
<span class="sd">            lora_uids (list):</span>
<span class="sd">                The uids of LoRA weights for the input batch. Use -1 to disable the LoRA module.</span>
<span class="sd">            kwargs (Dict[str, Any]):</span>
<span class="sd">                Ad hoc parametrization of sampling_config.</span>
<span class="sd">                The passed **kwargs matching the sampling_config&#39;s attributes will override them.</span>
<span class="sd">        Returns:</span>
<span class="sd">            torch.Tensor or dict:</span>
<span class="sd">                If return_dict=False, the method returns generated output_ids.</span>
<span class="sd">                If return_dict=True, the method returns a dict of output_ids,</span>
<span class="sd">                sequence_lengths (if sampling_config.output_sequence_lengths=True),</span>
<span class="sd">                context_logits and generation_logits (if self.session.gather_all_token_logits=True).</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="c1"># Use sampling_config like HF&#39;s generation_config</span>
        <span class="k">if</span> <span class="n">sampling_config</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">sampling_config</span> <span class="o">=</span> <span class="n">SamplingConfig</span><span class="p">(</span><span class="n">end_id</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">pad_id</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">sampling_config</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">sampling_config</span><span class="p">)</span>
        <span class="n">sampling_config</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="n">batch_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch_input_ids</span><span class="p">)</span>
        <span class="n">batch_input_ids</span><span class="p">,</span> <span class="n">input_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_inputs</span><span class="p">(</span>
            <span class="n">batch_input_ids</span><span class="p">,</span> <span class="n">sampling_config</span><span class="o">.</span><span class="n">pad_id</span><span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span><span class="p">:</span>
            <span class="k">assert</span> <span class="n">lora_uids</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> \
                <span class="s2">&quot;lora_uids should not be None for engine built with lora_plugin enabled.&quot;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">session</span><span class="o">.</span><span class="n">setup</span><span class="p">(</span>
            <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
            <span class="n">max_context_length</span><span class="o">=</span><span class="n">input_lengths</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
            <span class="n">max_new_tokens</span><span class="o">=</span><span class="n">sampling_config</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">,</span>
            <span class="n">beam_width</span><span class="o">=</span><span class="n">sampling_config</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span>
            <span class="n">max_kv_cache_length</span><span class="o">=</span><span class="n">sampling_config</span><span class="o">.</span><span class="n">max_kv_cache_length</span><span class="p">,</span>
            <span class="n">lora_manager</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span><span class="p">,</span>
            <span class="n">lora_uids</span><span class="o">=</span><span class="n">lora_uids</span><span class="p">)</span>
        <span class="n">batch_input_ids</span> <span class="o">=</span> <span class="n">batch_input_ids</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
        <span class="n">input_lengths</span> <span class="o">=</span> <span class="n">input_lengths</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
        <span class="n">ptuning_kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_ptuning</span><span class="p">(</span><span class="n">prompt_table_path</span><span class="p">,</span> <span class="n">prompt_tasks</span><span class="p">,</span>
                                               <span class="n">batch_size</span><span class="p">)</span>
        <span class="n">outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">session</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span>
            <span class="n">batch_input_ids</span><span class="p">,</span>
            <span class="n">input_lengths</span><span class="p">,</span>
            <span class="n">sampling_config</span><span class="p">,</span>
            <span class="n">stop_words_list</span><span class="o">=</span><span class="n">sampling_config</span><span class="o">.</span><span class="n">stop_words_list</span><span class="p">,</span>
            <span class="n">bad_words_list</span><span class="o">=</span><span class="n">sampling_config</span><span class="o">.</span><span class="n">bad_words_list</span><span class="p">,</span>
            <span class="n">output_sequence_lengths</span><span class="o">=</span><span class="n">sampling_config</span><span class="o">.</span><span class="n">output_sequence_lengths</span><span class="p">,</span>
            <span class="n">return_dict</span><span class="o">=</span><span class="n">sampling_config</span><span class="o">.</span><span class="n">return_dict</span><span class="p">,</span>
            <span class="n">streaming</span><span class="o">=</span><span class="n">streaming</span><span class="p">,</span>
            <span class="o">**</span><span class="n">ptuning_kwargs</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">sampling_config</span><span class="o">.</span><span class="n">return_dict</span><span class="p">:</span>
            <span class="k">if</span> <span class="n">streaming</span><span class="p">:</span>
                <span class="n">outputs</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_prepare_outputs</span><span class="p">(</span><span class="n">curr_outputs</span><span class="p">,</span> <span class="n">input_lengths</span><span class="p">)</span>
                           <span class="k">for</span> <span class="n">curr_outputs</span> <span class="ow">in</span> <span class="n">outputs</span><span class="p">)</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_outputs</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">input_lengths</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">outputs</span></div>
    <span class="k">def</span> <span class="nf">_prepare_ptuning</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prompt_table_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">tasks</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
                         <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">return</span> <span class="p">{}</span>
        <span class="k">if</span> <span class="n">prompt_table_path</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">prompt_table</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">prompt_table_path</span><span class="p">))</span>
            <span class="n">_</span><span class="p">,</span> <span class="n">task_vocab_size</span><span class="p">,</span> <span class="n">hidden_size</span> <span class="o">=</span> <span class="n">prompt_table</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
            <span class="n">task_vocab_size</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">task_vocab_size</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
            <span class="n">prompt_table</span> <span class="o">=</span> <span class="n">prompt_table</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
            <span class="n">prompt_table</span> <span class="o">=</span> <span class="n">prompt_table</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">session</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">prompt_table</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">session</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">])</span>
            <span class="n">task_vocab_size</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">1</span><span class="p">])</span>
        <span class="k">if</span> <span class="n">tasks</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">tasks</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">t</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tasks</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">&#39;,&#39;</span><span class="p">)],</span>
                                 <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
            <span class="k">assert</span> <span class="n">tasks</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> \
                <span class="sa">f</span><span class="s2">&quot;Number of supplied tasks (</span><span class="si">{</span><span class="n">tasks</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="si">}</span><span class="s2">) must match input batch size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">tasks</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">])</span>
        <span class="k">return</span> <span class="p">{</span>
            <span class="s1">&#39;prompt_embedding_table&#39;</span><span class="p">:</span> <span class="n">prompt_table</span><span class="o">.</span><span class="n">cuda</span><span class="p">(),</span>
            <span class="s1">&#39;tasks&#39;</span><span class="p">:</span> <span class="n">tasks</span><span class="o">.</span><span class="n">cuda</span><span class="p">(),</span>
            <span class="s1">&#39;prompt_vocab_size&#39;</span><span class="p">:</span> <span class="n">task_vocab_size</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
        <span class="p">}</span></div>
</pre></div>
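<p>The two input layouts that <code>_prepare_inputs</code> produces (right-padded rows, or a single packed row when <code>remove_input_padding</code> is enabled), and the per-sequence split that <code>_prepare_outputs</code> applies to packed context logits, can be illustrated with a minimal sketch. It assumes plain Python lists in place of <code>torch</code> tensors and needs no engine or GPU; the helper names <code>pad_right</code>, <code>pack</code> and <code>split_packed</code> are hypothetical and are not part of <code>tensorrt_llm</code>.</p>
<div class="highlight"><pre>
from itertools import accumulate


def pad_right(batch, pad_id):
    # Analogue of the padded branch: append pad_id until every row has the max length.
    max_len = max(len(seq) for seq in batch)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in batch]


def pack(batch):
    # Analogue of the remove_input_padding branch: concatenate all sequences into one row.
    return [[tok for seq in batch for tok in seq]]


def split_packed(flat, lengths):
    # Analogue of _prepare_outputs: cut a packed row at the cumulative sequence lengths.
    seg_points = [0] + list(accumulate(lengths))
    return [flat[s:e] for s, e in zip(seg_points[:-1], seg_points[1:])]


batch = [[11, 12, 13], [21], [31, 32]]
lengths = [len(seq) for seq in batch]
padded = pad_right(batch, pad_id=0)
packed = pack(batch)
restored = split_packed(packed[0], lengths)
print(padded)    # [[11, 12, 13], [21, 0, 0], [31, 32, 0]]
print(packed)    # [[11, 12, 13, 21, 31, 32]]
print(restored)  # [[11, 12, 13], [21], [31, 32]]
</pre></div>
<p>In this sketch <code>itertools.accumulate</code> stands in for <code>input_lengths.cumsum(dim=0).tolist()</code>; prepending <code>0</code> yields the same segment points, so each sequence is recovered from the packed row without needing a pad token.</p>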
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>&#169; Copyright 2023, NVidia.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>